warp-lang 1.4.2-py3-none-win_amd64.whl → 1.5.1-py3-none-win_amd64.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

This version of warp-lang has been flagged as a potentially problematic release.

Files changed (166)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
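
Most of the new surface in this diff is the tile programming path: warp/native/tile.h plus the tile_gemm/tile_reduce headers, the expanded builtins and stubs, the examples/tile/ scripts, and the test_tile* suites. The snippet below is a minimal sketch of the kind of kernel those files exercise, loosely mirroring examples/tile/example_tile_matmul.py from the list above; the wp.tile_* and wp.launch_tiled signatures are abbreviated from memory, not verified against this exact wheel, and a CUDA device is assumed.

    import numpy as np
    import warp as wp

    wp.init()

    # tile sizes are compile-time constants captured by the kernel
    TILE_M, TILE_N, TILE_K = 32, 32, 16

    @wp.kernel
    def gemm_tiled(A: wp.array2d(dtype=wp.float32),
                   B: wp.array2d(dtype=wp.float32),
                   C: wp.array2d(dtype=wp.float32)):
        # each thread block cooperates on one TILE_M x TILE_N output tile
        i, j = wp.tid()
        acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)
        count = int(A.shape[1] / TILE_K)
        for k in range(count):
            a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)  # global -> tile (shared/registers)
            b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)
            wp.tile_matmul(a, b, acc)                       # acc += a @ b
        wp.tile_store(C, i, j, acc)

    A = wp.array(np.random.rand(64, 64), dtype=wp.float32)
    B = wp.array(np.random.rand(64, 64), dtype=wp.float32)
    C = wp.zeros((64, 64), dtype=wp.float32)

    # launch_tiled gives each logical (i, j) tile index block_dim cooperating threads
    wp.launch_tiled(gemm_tiled, dim=(64 // TILE_M, 64 // TILE_N), inputs=[A, B, C], block_dim=64)

The cooperative per-tile work behind these builtins (shared-memory staging, block-wide synchronization, and the forward/adjoint GEMM dispatch) is what the new tile.h shown below implements.
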
warp/native/tile.h ADDED
@@ -0,0 +1,1854 @@
1
+ /** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ * and proprietary rights in and to this software, related documentation
4
+ * and any modifications thereto. Any use, reproduction, disclosure or
5
+ * distribution of this software and related documentation without an express
6
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ #include "builtin.h"
12
+
13
+ #if !defined(__CUDA_ARCH__)
14
+ #define WP_TILE_SHARED static
15
+ #define WP_TILE_SYNC void
16
+ #else
17
+ #define WP_TILE_SHARED __shared__
18
+ #define WP_TILE_SYNC __syncthreads
19
+ #endif
20
+
21
+ #if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
22
+ #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
23
+ #define WP_PRAGMA_UNROLL _Pragma("unroll")
24
+ #define WP_PRAGMA_NO_UNROLL _Pragma("unroll 1")
25
+ #else
26
+ #define WP_PRAGMA_UNROLL #pragma unroll
27
+ #define WP_PRAGMA_NO_UNROLL #pragma unroll 1
28
+ #endif
29
+
30
+ #else
31
+
32
+ #define WP_PRAGMA_UNROLL
33
+ #define WP_PRAGMA_NO_UNROLL
34
+
35
+ #endif
36
+
37
+ #define WP_USE_ASYNC_PIPELINE 0
38
+ #if WP_USE_ASYNC_PIPELINE
39
+ #include "cuda_pipeline_primitives.h"
40
+ #endif // WP_USE_ASYNC_PIPELINE
41
+
42
+ #define WP_USE_REGISTER_GEMM 0
43
+
44
+ /* Tile Expressions
45
+
46
+ [ ] Tiles
47
+ [x] Register, Shared, Global
48
+ [ ] Layouts
49
+ [x] Simple
50
+ [ ] Cute
51
+ [x] Remove Alloc type from tile_shared_t
52
+ [x] wp.launch_tiled() helper
53
+ [ ] Creation
54
+ [x] zeros
55
+ [x] ones
56
+ [x] arange
57
+ [x] tile()
58
+ [x] untile()
59
+ [ ] fromfunction()
60
+ [ ] explicit storage
61
+ [ ] Load/Store
62
+ [ ] 1D load/store variants
63
+ [ ] max_coord option for non-aligned loads
64
+ [ ] Indexed load
65
+ [x] wp.tile_atomic_add()
66
+ [ ] Maps
67
+ [x] Support user functions
68
+ [x] Support built-in functions
69
+ [ ] Support for lambda functions
70
+ [ ] Infer tile_map() output from operator type (e.g.: dot for each element)
71
+ [ ] Reductions
72
+ [x] Sum
73
+ [x] Forward
74
+ [x] Reverse
75
+ [x] Min
76
+ [x] Max
77
+ [x] Custom
78
+ [x] MatMul
79
+ [x] Forward
80
+ [x] Reverse
81
+ [ ] Operators
82
+ [ ] +, -, *, /, @?
83
+ [ ] += for matmul, e.g.: c += a@b, or c = a@b
84
+ [ ] Reshape
85
+ [ ] Broadcasting
86
+ [ ] Transpose
87
+ [x] Shared
88
+ [ ] Register
89
+ [ ] Slice
90
+ [ ] Runtime
91
+ [x] Compile-time block dimensions
92
+ [x] Switch between SIMT / Tile based execution if `block_dim` not provided to wp.launch()
93
+ [ ] Examples
94
+ [ ] Point registration
95
+ [ ] GEMM
96
+ [ ] MLP
97
+ [ ] LayerNorm
98
+ [ ] SoftMax
99
+ [ ] GEMM
100
+ [ ] warp.sim (CRBA)
101
+ [ ] Batched MLP
102
+ [ ] Layer norm
103
+ [ ] FNO + Burgers equation
104
+ [ ] Stochastic financial modeling
105
+ [ ] Convolution: https://github.com/NVIDIA/MinkowskiEngine/blob/master/src/convolution_kernel.cu#L123
106
+ [ ] MeshCNN (Modulus, Oliver)
107
+ [ ] BioNemo (Ali)
108
+ [ ] Skinning (David/Or/Vismay)
109
+ [ ] warp.sim (VBD)
110
+ [ ] Error checking
111
+ [ ] Ensure functions passed to tile_map() are compatible with tile type
112
+ [ ] Ensure that args passed to tile ops are compatible
113
+ [ ] Ensure tile load/store operations don't go out of bounds of arrays in debug mode
114
+
115
+ */
116
+
117
+ /*
118
+ Notes on shared memory synchronization
119
+ ======================================
120
+
121
+ Currently operations that write to shared memory tiles (e.g.: tile_load())
122
+ must synchronize before they return through WP_TILE_SYNC(); this
123
+ ensures subsequent read operations from the tile do not cause a race condition.
124
+
125
+ For tile_shared_t adjoints, the gradient accumulation is done through shared
126
+ memory atomics, i.e.: atomic_add(), since for broadcast tiles multiple threads
127
+ may map to the same location. Synchronization is still required after these
128
+ updates, since subsequent operations e.g.: adj_tile_load() will store the
129
+ gradients to memory, and all updates must be visible at that point, e.g.:
130
+
131
+ a = wp.tile_load(...)
132
+ b = wp.tile_load(...)
133
+ c = wp.tile_matmul(a, b)
134
+ wp.tile_store(c)
135
+
136
+ // loads incoming adjoints from global -> shared
137
+ wp.adj_tile_store(c, adj_c)
138
+ // consumes adj_c, requires synchronization
139
+ wp.adj_tile_matmul(a, b, adj_a, adj_b, adj_c)
140
+ // consumes adj_b, requires synchronization
141
+ wp.adj_tile_load(..., adj_b)
142
+ // consumes adj_b, requires synchronization
143
+ wp.adj_tile_load(..., adj_a)
144
+
145
+ Generally synchronization to adjoint tiles will happen through the
146
+ tile_shared_t::add() and tile_shared_t::assign() functions automatically,
147
+ but in some cases e.g.: tile_matmul() it is done manually.
148
+
149
+ The current synchronization strategy is conservative, and can lead to more
150
+ synchronization than necessary. A more sophisticated strategy would be
151
+ to track the 'dirty' state of shared tiles, and synchronize only when
152
+ necessary. In addition, custom synchronization for e.g.: tile_load()
153
+ operations could be added through a SyncProvider template parameter on
154
+ the tile_shared_t type, for example to support barrier synchronization
155
+ for asynchronous global to shared loads.
156
+ */
157
+
158
+ namespace wp
159
+ {
160
+
161
+ // Primary template
162
+ template <typename T, typename U>
163
+ struct is_same {
164
+ static constexpr bool value = false;
165
+ };
166
+
167
+ // Specialization for the case when T and U are the same type
168
+ template <typename T>
169
+ struct is_same<T, T> {
170
+ static constexpr bool value = true;
171
+ };
172
+
173
+
174
+ template <typename Tile>
175
+ constexpr int tile_size(Tile& t) { return Tile::M*Tile::N; }
176
+
177
+ constexpr int tile_regcount(int m, int n) {
178
+ return (m*n + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
179
+ }
180
+
181
+ struct coord_t
182
+ {
183
+ int i;
184
+ int j;
185
+ };
186
+
187
+
188
+ // represents a tile stored in global memory with dynamic strides
189
+ // only used to represent the source for tile loads to register/shared
190
+ template <typename T>
191
+ struct tile_global_t
192
+ {
193
+ using Type = T;
194
+
195
+ array_t<T> data;
196
+ int x;
197
+ int y;
198
+
199
+ tile_global_t(array_t<T>& a, int x, int y) : data(a), x(x), y(y)
200
+ {
201
+ }
202
+ };
203
+
204
+ // represents a tile stored in registers across a block
205
+ template <typename T, int M_, int N_>
206
+ struct tile_register_t
207
+ {
208
+ using Type = T;
209
+ static constexpr int M = M_;
210
+ static constexpr int N = N_;
211
+ static constexpr int Size = M*N;
212
+
213
+ static constexpr int NumRegs = tile_regcount(M, N);
214
+
215
+ static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
216
+
217
+ T data[NumRegs];
218
+
219
+ inline CUDA_CALLABLE tile_register_t(T value=T(0.0))
220
+ {
221
+ // zero-initialize by default necessary for tile adjoints
222
+ // need to check if this results in worse codegen
223
+ // than doing adj_var = tile_zeros() explicitly
224
+ // in backwards pass and letting default constructor
225
+ // avoid initialization
226
+
227
+ for (int i=0; i < NumRegs; ++i)
228
+ data[i] = value;
229
+ }
230
+
231
+ inline CUDA_CALLABLE auto& operator=(const tile_global_t<T>& t)
232
+ {
233
+ if (t.data.ndim == 1)
234
+ copy_from_global(t.data, t.x); // 1d load
235
+ else
236
+ copy_from_global(t.data, t.x, t.y); // 2d load
237
+
238
+ return *this;
239
+
240
+ }
241
+
242
+ // define the += operator which is used during backward pass codegen
243
+ // when returning a register tile from a user defined function
244
+ inline CUDA_CALLABLE auto& operator += (tile_register_t<T, M, N>& rhs)
245
+ {
246
+ this->grad_add(rhs);
247
+ return *this;
248
+ }
249
+
250
+ inline CUDA_CALLABLE T& operator()(int index)
251
+ {
252
+ assert(index < NumRegs);
253
+ return data[index];
254
+ }
255
+
256
+ inline CUDA_CALLABLE const T& operator()(int index) const
257
+ {
258
+ assert(index < NumRegs);
259
+ return data[index];
260
+ }
261
+
262
+
263
+ // compute linear tile index from a local register index
264
+ inline CUDA_CALLABLE int index(int reg) const
265
+ {
266
+ return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
267
+ }
268
+
269
+ // compute tile coordinate from linear index
270
+ inline CUDA_CALLABLE coord_t coord(int index) const
271
+ {
272
+ return {index/N, index%N};
273
+ }
274
+
275
+ // Returns the number of valid registers for this tile
276
+ // i.e.: how many registers map to a valid coordinate.
277
+ // When a tile's size is not aligned to the block dimension
278
+ // some of the trailing registers may lie outside the valid range
279
+ inline CUDA_CALLABLE int valid() const
280
+ {
281
+ return (Size - threadIdx.x)/WP_TILE_BLOCK_DIM;
282
+ }
283
+
284
+ inline CUDA_CALLABLE void assign(const tile_register_t<T, M, N>& tile)
285
+ {
286
+ for (int i=0; i < NumRegs; ++i)
287
+ data[i] = tile.data[i];
288
+ }
289
+
290
+ inline CUDA_CALLABLE void zero()
291
+ {
292
+ for (int i=0; i < NumRegs; ++i)
293
+ data[i] = T(0);
294
+ }
295
+
296
+ // extract a single tile element to a native type
297
+ inline CUDA_CALLABLE Type extract(int i, int j)
298
+ {
299
+ // map from logical coords (i, j) -> (thread, reg)
300
+ const int linear = i*N + j;
301
+
302
+ const int thread = linear/NumRegs;
303
+ const int reg = linear%NumRegs;
304
+
305
+ WP_TILE_SHARED Type scratch;
306
+
307
+ // ensure any previously scheduled threads have finished reading from scratch
308
+ WP_TILE_SYNC();
309
+
310
+ if (threadIdx.x == thread)
311
+ {
312
+ scratch = data[reg];
313
+ }
314
+
315
+ // ensure extraction thread has updated smem
316
+ WP_TILE_SYNC();
317
+
318
+ return scratch;
319
+ }
320
+
321
+
322
+ // backward version of scalar extract
323
+ inline CUDA_CALLABLE void adj_extract(int i, int j, Type adj_ret)
324
+ {
325
+ // map from logical coords (i, j) -> (thread, reg)
326
+ const int linear = i*N + j;
327
+
328
+ const int thread = linear/NumRegs;
329
+ const int reg = linear%NumRegs;
330
+
331
+ if (threadIdx.x == thread)
332
+ {
333
+ data[reg] += adj_ret;
334
+ }
335
+ }
336
+
337
+ inline CUDA_CALLABLE void print() const;
338
+
339
+
340
+ // return the in-register version of this tile (nop)
341
+ inline CUDA_CALLABLE auto& copy_to_register()
342
+ {
343
+ return *this;
344
+ }
345
+
346
+ inline CUDA_CALLABLE const auto& copy_to_register() const
347
+ {
348
+ return *this;
349
+ }
350
+
351
+ // in-place gradient zero
352
+ inline CUDA_CALLABLE void grad_zero()
353
+ {
354
+ zero();
355
+ }
356
+
357
+ // accumulate gradients onto this tile
358
+ inline CUDA_CALLABLE void grad_add(const tile_register_t<T, M, N>& tile)
359
+ {
360
+ for (int i=0; i < NumRegs; ++i)
361
+ data[i] += tile.data[i];
362
+ }
363
+
364
+ // copy shared tile to register
365
+ inline CUDA_CALLABLE auto& grad_to_register()
366
+ {
367
+ return *this;
368
+ }
369
+
370
+ void copy_to_global(array_t<T> dest, int x)
371
+ {
372
+ assert(dest.ndim == 1);
373
+
374
+ const int tile_i = x*N;
375
+
376
+ WP_PRAGMA_UNROLL
377
+ for (int i=0; i < NumRegs; ++i)
378
+ {
379
+ // handle case where tile size is not
380
+ // aligned to block dimensions
381
+ int linear = index(i);
382
+ if (!Aligned && linear >= Size)
383
+ break;
384
+
385
+ wp::index(dest, tile_i + linear) = data[i];
386
+ }
387
+ }
388
+
389
+ void copy_to_global(array_t<T> dest, int x, int y)
390
+ {
391
+ assert(dest.ndim == 2);
392
+
393
+ const int tile_i = x*M;
394
+ const int tile_j = y*N;
395
+
396
+ // wp.array() indexing generates poor code due to char* casting
397
+ // here we unroll some of the ops, note this assumes byte strides are
398
+ // aligned to the element size
399
+ T* ptr = &wp::index(dest, tile_i, tile_j);
400
+ const int stride_i = dest.strides[0]/sizeof(T);
401
+ const int stride_j = dest.strides[1]/sizeof(T);
402
+
403
+ WP_PRAGMA_UNROLL
404
+ for (int i=0; i < NumRegs; ++i)
405
+ {
406
+ // handle case where tile size is not
407
+ // aligned to block dimensions
408
+ int linear = index(i);
409
+ if (!Aligned && linear >= Size)
410
+ break;
411
+
412
+ coord_t c = coord(linear);
413
+ ptr[c.i*stride_i + c.j*stride_j] = data[i];
414
+ }
415
+ }
416
+
417
+ inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x)
418
+ {
419
+ // todo: use async pipelines or TMA here
420
+ const int tile_i = x*N;
421
+
422
+ WP_PRAGMA_UNROLL
423
+ for (int i=0; i < NumRegs; ++i)
424
+ {
425
+ int linear = index(i);
426
+ if (!Aligned && linear >= Size)
427
+ break;
428
+
429
+ data[i] = wp::index(src, tile_i + linear);
430
+ }
431
+ }
432
+
433
+ inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x, int y)
434
+ {
435
+ // todo: use async pipelines or TMA here
436
+ const int tile_i = x*M;
437
+ const int tile_j = y*N;
438
+
439
+ // wp.array() indexing generates poor code due to char* casting
440
+ // here we unroll some of the ops, note this assumes array byte strides are
441
+ // aligned to the element size
442
+ const T* ptr = &wp::index(src, tile_i, tile_j);
443
+
444
+ assert(src.strides[0]%sizeof(T) == 0);
445
+ assert(src.strides[1]%sizeof(T) == 0);
446
+
447
+ const int stride_i = src.strides[0]/sizeof(T);
448
+ const int stride_j = src.strides[1]/sizeof(T);
449
+
450
+ WP_PRAGMA_UNROLL
451
+ for (int i=0; i < NumRegs; ++i)
452
+ {
453
+ int linear = index(i);
454
+ if (!Aligned && linear >= Size)
455
+ break;
456
+
457
+ coord_t c = coord(linear);
458
+ data[i] = ptr[c.i*stride_i + c.j*stride_j];
459
+ }
460
+ }
461
+ };
462
+
463
+ // helper to allocate a register tile like another tile
464
+ template<typename Tile>
465
+ auto tile_register_like()
466
+ {
467
+ using T = typename Tile::Type;
468
+
469
+ return tile_register_t<T, Tile::M, Tile::N>(T(0.0));
470
+ }
471
+
472
+ inline CUDA_CALLABLE int tile_align(int num_bytes)
473
+ {
474
+ // note this must match the value in Python types.py
475
+ const int alignment = 16;
476
+
477
+ return ((num_bytes + alignment - 1) / alignment) * alignment;
478
+ }
479
+
480
+ inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
481
+ {
482
+ // we maintain a per-thread offset into dynamic
483
+ // shared memory that allows us to keep track of
484
+ // current use across dynamic function calls
485
+ __shared__ int smem_base[WP_TILE_BLOCK_DIM];
486
+
487
+ if (init)
488
+ {
489
+ smem_base[threadIdx.x] = 0;
490
+ return NULL;
491
+ }
492
+ else
493
+ {
494
+ const int offset = smem_base[threadIdx.x];
495
+
496
+ // one entry per-thread so no need for synchronization
497
+ smem_base[threadIdx.x] += tile_align(num_bytes);
498
+
499
+ extern __shared__ char dynamic_smem_base[];
500
+ return &(dynamic_smem_base[offset]);
501
+ }
502
+ }
503
+
504
+
505
+
506
+ template <typename T, int M_, int N_, int StrideM_=N_, int StrideN_=1, bool Owner_=true>
507
+ struct tile_shared_t
508
+ {
509
+ using Type = T;
510
+ static constexpr int M = M_;
511
+ static constexpr int N = N_;
512
+ static constexpr int Size = M*N;
513
+
514
+ static constexpr int StrideM = StrideM_;
515
+ static constexpr int StrideN = StrideN_;
516
+
517
+ static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
518
+ static constexpr bool Unique = (StrideM >= N) && (StrideN >= 1);
519
+ static constexpr bool Owner = Owner_;
520
+
521
+ struct Storage
522
+ {
523
+ T* ptr;
524
+
525
+ Storage(T* p) : ptr(p) {}
526
+
527
+ inline CUDA_CALLABLE T& operator()(int i, int j)
528
+ {
529
+ assert(i < M);
530
+ assert(j < N);
531
+
532
+ return ptr[i*StrideM + j*StrideN];
533
+ }
534
+
535
+ inline CUDA_CALLABLE const T& operator()(int i, int j) const
536
+ {
537
+ assert(i < M);
538
+ assert(j < N);
539
+
540
+ return ptr[i*StrideM + j*StrideN];
541
+ }
542
+
543
+ inline CUDA_CALLABLE T& operator()(int index)
544
+ {
545
+ assert(index < M*N);
546
+
547
+ // unravel
548
+ int i = index/N;
549
+ int j = index%N;
550
+
551
+ return (*this)(i,j);
552
+ }
553
+
554
+ inline CUDA_CALLABLE const T& operator()(int index) const
555
+ {
556
+ assert(index < M*N);
557
+
558
+ // unravel
559
+ int i = index/N;
560
+ int j = index%N;
561
+
562
+ return (*this)(i,j);
563
+ }
564
+ };
565
+
566
+ Storage data;
567
+ Storage grad;
568
+
569
+ // default initialization (non-initialized)
570
+ inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL)
571
+ {
572
+ }
573
+
574
+ // initialize from an existing tile's memory
575
+ inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL) : data(data), grad(grad)
576
+ {
577
+ }
578
+
579
+ inline CUDA_CALLABLE ~tile_shared_t()
580
+ {
581
+ if (Owner)
582
+ {
583
+ // update our per-thread shared memory allocator
584
+ if (data.ptr)
585
+ tile_alloc_shared(-M*N*int(sizeof(T)));
586
+
587
+ if (grad.ptr)
588
+ tile_alloc_shared(-M*N*int(sizeof(T)));
589
+ }
590
+ }
591
+
592
+ // assign from a register tile
593
+ template <typename Tile>
594
+ inline CUDA_CALLABLE auto& operator=(const Tile& t)
595
+ {
596
+ assign(t);
597
+ return *this;
598
+ }
599
+
600
+ // construct from another shared tile, this constructor
601
+ // is invoked for reshape operations like `wp.tile_transpose()`
602
+ template <typename OtherT, int OtherM, int OtherN, int OtherStrideM, int OtherStrideN>
603
+ inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherM, OtherN, OtherStrideM, OtherStrideN>& rhs)
604
+ {
605
+ using OtherTile = tile_shared_t<OtherT, OtherM, OtherN, OtherStrideM, OtherStrideN>;
606
+
607
+ // check dimensions are compatible
608
+ static_assert(Size == OtherTile::Size);
609
+
610
+ // alias tile directly
611
+ data = rhs.data;
612
+ grad = rhs.grad;
613
+
614
+ return *this;
615
+ }
616
+
617
+ // assign from a global tile (load)
618
+ inline CUDA_CALLABLE auto& operator=(const tile_global_t<T>& t)
619
+ {
620
+ if (t.data.ndim == 1)
621
+ copy_from_global(t.data, t.x); // 1d load
622
+ else
623
+ copy_from_global(t.data, t.x, t.y); // 2d load
624
+
625
+ // synchronization happens in copy functions above
626
+
627
+ return *this;
628
+ }
629
+
630
+ // assign from a constant value
631
+ inline CUDA_CALLABLE auto& operator=(const T& x)
632
+ {
633
+ for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
634
+ data(i) = x;
635
+
636
+ WP_TILE_SYNC();
637
+ return *this;
638
+ }
639
+
640
+
641
+ // compute tile coordinate from linear index
642
+ inline CUDA_CALLABLE coord_t coord(int index) const
643
+ {
644
+ return {index/N, index%N};
645
+ }
646
+
647
+ // in-place zero
648
+ inline CUDA_CALLABLE void zero()
649
+ {
650
+ for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
651
+ data(i) = T(0);
652
+
653
+ WP_TILE_SYNC();
654
+ }
655
+
656
+ // extract a single tile element to a native type
657
+ inline CUDA_CALLABLE Type extract(int i, int j)
658
+ {
659
+ return data(i, j);
660
+ }
661
+
662
+ // backward of scalar extraction
663
+ inline CUDA_CALLABLE void adj_extract(int i, int j, Type adj_ret)
664
+ {
665
+ if (threadIdx.x == 0)
666
+ data(i, j) += adj_ret;
667
+
668
+ WP_TILE_SYNC();
669
+ }
670
+
671
+
672
+ // copy register tile to shared
673
+ inline CUDA_CALLABLE void assign(const tile_register_t<T, M, N>& tile)
674
+ {
675
+ WP_PRAGMA_UNROLL
676
+ for (int i=0; i < tile.NumRegs; ++i)
677
+ {
678
+ const int linear = tile.index(i);
679
+
680
+ // handle case where tile size is not
681
+ // aligned to block dimensions
682
+ if (!Aligned && linear >= Size)
683
+ break;
684
+
685
+ data(linear) = tile.data[i];
686
+ }
687
+
688
+ WP_TILE_SYNC();
689
+ }
690
+
691
+ // in-place gradient zero
692
+ inline CUDA_CALLABLE void grad_zero()
693
+ {
694
+ // todo: make this subtile (stride aware)
695
+ for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
696
+ grad(i) = T(0);
697
+
698
+ WP_TILE_SYNC();
699
+ }
700
+
701
+
702
+ // accumulate gradients onto this tile
703
+ inline CUDA_CALLABLE void grad_add(const tile_register_t<T, M, N>& tile)
704
+ {
705
+ WP_PRAGMA_UNROLL
706
+ for (int i=0; i < tile.NumRegs; ++i)
707
+ {
708
+ const int linear = tile.index(i);
709
+
710
+ // handle case where tile size is not
711
+ // aligned to block dimensions
712
+ if (!Aligned && linear >= Size)
713
+ break;
714
+
715
+ if (Unique)
716
+ grad(linear) += tile.data[i];
717
+ else
718
+ // use shared memory atomics to accumulate gradients
719
+ // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
720
+ // may map to a single location in shared memory
721
+ atomic_add(&grad(linear), tile.data[i]);
722
+
723
+ }
724
+
725
+ WP_TILE_SYNC();
726
+ }
727
+
728
+ // copy shared tile to register
729
+ inline CUDA_CALLABLE tile_register_t<T, M, N> grad_to_register()
730
+ {
731
+ tile_register_t<T, M, N> out;
732
+
733
+ WP_PRAGMA_UNROLL
734
+ for (int i=0; i < out.NumRegs; ++i)
735
+ {
736
+ const int linear = out.index(i);
737
+
738
+ // handle case where tile size is not
739
+ // aligned to block dimensions
740
+ if (!Aligned && linear >= Size)
741
+ break;
742
+
743
+ out(i) = grad(linear);
744
+ }
745
+
746
+ return out;
747
+ }
748
+
749
+ inline CUDA_CALLABLE void print() const
750
+ {
751
+ if (threadIdx.x == 0)
752
+ {
753
+ printf("tile(m=%d, n=%d, storage=shared) = [", M, N);
754
+ for (int i=0; i < M; ++i)
755
+ {
756
+ printf("%*s[", i>0, "");
757
+ for (int j=0; j < N; ++j)
758
+ {
759
+ printf("%g ", double(data(i, j)));
760
+ }
761
+
762
+ if (i == M-1)
763
+ printf("]]\n");
764
+ else
765
+ printf("]\n");
766
+ }
767
+ }
768
+ }
769
+
770
+ // copy shared tile to register
771
+ inline CUDA_CALLABLE tile_register_t<T, M, N> copy_to_register() const
772
+ {
773
+ tile_register_t<T, M, N> out;
774
+
775
+ WP_PRAGMA_UNROLL
776
+ for (int i=0; i < out.NumRegs; ++i)
777
+ {
778
+ const int linear = out.index(i);
779
+
780
+ // handle case where tile size is not
781
+ // aligned to block dimensions
782
+ if (!Aligned && linear >= Size)
783
+ break;
784
+
785
+ out(i) = data(linear);
786
+ }
787
+
788
+ return out;
789
+ }
790
+
791
+ inline CUDA_CALLABLE void copy_to_global(array_t<T> dest, int x) const
792
+ {
793
+ assert(dest.ndim == 1);
794
+
795
+ // todo: use TMA here
796
+ const int tile_i = x*N;
797
+
798
+ WP_PRAGMA_UNROLL
799
+ for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
800
+ {
801
+ wp::index(dest, tile_i + i) = data(i);
802
+ }
803
+ }
804
+
805
+ inline CUDA_CALLABLE void copy_to_global(array_t<T> dest, int x, int y)
806
+ {
807
+ // todo: use TMA here
808
+ const int tile_i = x*M;
809
+ const int tile_j = y*N;
810
+
811
+ // check each row is contiguous and 128bit aligned
812
+ if (StrideN == 1 && dest.strides[1] == sizeof(T) && (N*sizeof(T))%sizeof(float4) == 0)
813
+ {
814
+ constexpr int num_rows = M;
815
+ constexpr int num_cols = (N*sizeof(T))/sizeof(float4);
816
+
817
+ tile_shared_t<float4, num_rows, num_cols> src128((float4*)data.ptr);
818
+
819
+ // alias of shared tile with 128bit type
820
+ float4* ptr = (float4*)&wp::index(dest, tile_i, tile_j);
821
+
822
+ assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
823
+ assert(((uint64_t)(ptr))%sizeof(float4) == 0);
824
+
825
+ const int stride_i = dest.strides[0]/sizeof(float4);
826
+ const int stride_j = 1;
827
+
828
+ WP_PRAGMA_UNROLL
829
+ for (int i=threadIdx.x; i < src128.Size; i += WP_TILE_BLOCK_DIM)
830
+ {
831
+ coord_t c = src128.coord(i);
832
+ ptr[c.i*stride_i + c.j*stride_j] = src128.data(i);
833
+ }
834
+ }
835
+ else
836
+ {
837
+ // wp.array() indexing generates poor code due to char* casting
838
+ // here we unroll some of the ops, note this assumes byte strides are
839
+ // aligned to the element size
840
+ T* ptr = &wp::index(dest, tile_i, tile_j);
841
+ const int stride_i = dest.strides[0]/sizeof(T);
842
+ const int stride_j = dest.strides[1]/sizeof(T);
843
+
844
+ WP_PRAGMA_UNROLL
845
+ for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
846
+ {
847
+ coord_t c = coord(i);
848
+ ptr[c.i*stride_i + c.j*stride_j] = data(c.i, c.j);
849
+ }
850
+ }
851
+ }
852
+
853
+ inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x)
854
+ {
855
+ // todo: use async pipelines or TMA here
856
+ const int tile_i = x*N;
857
+
858
+ WP_PRAGMA_UNROLL
859
+ for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
860
+ {
861
+ data(i) = wp::index(src, tile_i + i);
862
+ }
863
+
864
+ WP_TILE_SYNC();
865
+ }
866
+
867
+ inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x, int y)
868
+ {
869
+ // todo: use async pipelines or TMA here
870
+ const int tile_i = x*M;
871
+ const int tile_j = y*N;
872
+
873
+ // check each row is contiguous and 128bit aligned
874
+ if (StrideN == 1 && src.strides[1] == sizeof(T) && (N*sizeof(T))%sizeof(float4) == 0)
875
+ {
876
+ constexpr int num_rows = M;
877
+ constexpr int num_cols = (N*sizeof(T))/sizeof(float4);
878
+
879
+ // alias of shared tile with 128bit type
880
+ tile_shared_t<float4, num_rows, num_cols> dest128((float4*)data.ptr);
881
+
882
+ const float4* ptr = (const float4*)&wp::index(src, tile_i, tile_j);
883
+
884
+ assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
885
+ assert(((uint64_t)(ptr))%sizeof(float4) == 0);
886
+
887
+ const int stride_i = src.strides[0]/sizeof(float4);
888
+ //const int stride_j = 1;
889
+
890
+ WP_PRAGMA_UNROLL
891
+ for (int i=threadIdx.x; i < dest128.Size; i += WP_TILE_BLOCK_DIM)
892
+ {
893
+ coord_t c = dest128.coord(i);
894
+
895
+ #if WP_USE_ASYNC_PIPELINE
896
+ __pipeline_memcpy_async(&dest128.data(i),
897
+ &ptr[c.i*stride_i + c.j],
898
+ sizeof(float4));
899
+ #else
900
+ dest128.data(i) = ptr[c.i*stride_i + c.j];
901
+ #endif // WP_USE_ASYNC_PIPELINE
902
+ }
903
+
904
+ #if WP_USE_ASYNC_PIPELINE
905
+ __pipeline_commit();
906
+ #endif // WP_USE_ASYNC_PIPELINE
907
+
908
+ }
909
+ else
910
+ {
911
+ // wp.array() indexing generates poor code due to char* casting
912
+ // here we unroll some of the ops, note this assumes array byte strides are
913
+ // aligned to the element size
914
+ const T* ptr = &wp::index(src, tile_i, tile_j);
915
+
916
+ assert(src.strides[0]%sizeof(T) == 0);
917
+ assert(src.strides[1]%sizeof(T) == 0);
918
+
919
+ const int stride_i = src.strides[0]/sizeof(T);
920
+ const int stride_j = src.strides[1]/sizeof(T);
921
+
922
+ WP_PRAGMA_UNROLL
923
+ for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
924
+ {
925
+ coord_t c = coord(i);
926
+ data(c.i, c.j) = ptr[c.i*stride_i + c.j*stride_j];
927
+ }
928
+ }
929
+
930
+ #if !WP_USE_ASYNC_PIPELINE
931
+ WP_TILE_SYNC();
932
+ #endif
933
+
934
+ }
935
+ };
936
+
937
+ template <typename T, int M, int N>
938
+ void tile_register_t<T, M, N>::print() const
939
+ {
940
+ // create a temporary shared tile so that
941
+ // we can print it deterministically
942
+ WP_TILE_SHARED T smem[M*N];
943
+
944
+ tile_shared_t<T, M, N> scratch(smem, NULL);
945
+ scratch.assign(*this);
946
+
947
+ WP_TILE_SYNC();
948
+
949
+ if (threadIdx.x == 0)
950
+ {
951
+ printf("tile(m=%d, n=%d, storage=register) = [", M, N);
952
+ for (int i=0; i < M; ++i)
953
+ {
954
+ printf("%*s[", i>0, "");
955
+ for (int j=0; j < N; ++j)
956
+ {
957
+ printf("%g ", double(scratch.data(i, j)));
958
+ }
959
+
960
+ if (i == M-1)
961
+ printf("]]\n");
962
+ else
963
+ printf("]\n");
964
+ }
965
+ }
966
+
967
+ WP_TILE_SYNC();
968
+ }
969
+
970
+ template <typename T, int M, int N>
971
+ inline CUDA_CALLABLE void print(const tile_register_t<T, M, N>& t)
972
+ {
973
+ t.print();
974
+ }
975
+
976
+ template <typename T, int M, int N>
977
+ inline CUDA_CALLABLE void adj_print(const tile_register_t<T, M, N>& t, const tile_register_t<T, M, N>& a)
978
+ {
979
+ a.print();
980
+ }
981
+
982
+ template <typename T, int M, int N, int StrideM, int StrideN, bool Owner>
983
+ inline CUDA_CALLABLE void print(const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& t)
984
+ {
985
+ t.print();
986
+ }
987
+
988
+ template <typename T, int M, int N, int StrideM, int StrideN, bool Owner>
989
+ inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& t, const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& a)
990
+ {
991
+ a.print();
992
+ }
993
+
994
+ // helpers to allocate shared tiles
995
+ template <typename T, int M, int N, bool RequiresGrad>
996
+ inline CUDA_CALLABLE auto tile_alloc_empty()
997
+
998
+ { constexpr int Len = M*N;
999
+ T* data = (T*)tile_alloc_shared(Len*sizeof(T));
1000
+ T* grad = NULL;
1001
+
1002
+ #if FP_CHECK
1003
+
1004
+ for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
1005
+ data[i] = T(nanf(""));
1006
+
1007
+ WP_TILE_SYNC();
1008
+
1009
+ #endif // FP_CHECK
1010
+
1011
+
1012
+ if (RequiresGrad)
1013
+ {
1014
+ grad = (T*)tile_alloc_shared(Len*sizeof(T));
1015
+
1016
+ for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
1017
+ grad[i] = T(0);
1018
+
1019
+ WP_TILE_SYNC();
1020
+ }
1021
+
1022
+ return tile_shared_t<T, M, N>(data, grad);
1023
+ }
1024
+
1025
+ template <typename T, int M, int N, bool RequiresGrad>
1026
+ inline CUDA_CALLABLE auto tile_alloc_zeros()
1027
+ {
1028
+ // compute the total storage required for the tile (may be different from M*N) for broadcast tiles
1029
+ constexpr int Len = M*N;
1030
+ T* data = (T*)tile_alloc_shared(Len*sizeof(T));
1031
+ T* grad = NULL;
1032
+
1033
+ for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
1034
+ data[i] = T(0);
1035
+
1036
+ if (RequiresGrad)
1037
+ {
1038
+ grad = (T*)tile_alloc_shared(Len*sizeof(T));
1039
+
1040
+ for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
1041
+ grad[i] = T(0);
1042
+ }
1043
+
1044
+ WP_TILE_SYNC();
1045
+
1046
+ return tile_shared_t<T, M, N, StrideM, StrideN>(data, grad);
1047
+ }
1048
+
1049
+
1050
+ //-----------------------------------------------------------------------------------------------------
1051
+ // High level entry points for each op (correspond to one Warp builtin)
1052
+
1053
+ // construct a tile from a local SIMT value (one per-thread)
1054
+ template <typename T>
1055
+ inline CUDA_CALLABLE auto tile(const T& x)
1056
+ {
1057
+ tile_register_t<T, 1, WP_TILE_BLOCK_DIM> result;
1058
+
1059
+ static_assert(result.NumRegs == 1);
1060
+
1061
+ result.data[0] = x;
1062
+ return result;
1063
+ }
1064
+
1065
+ // overload for constructing a tile from a per-thread vector
1066
+ template <typename T, unsigned Length>
1067
+ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
1068
+ {
1069
+ tile_register_t<T, Length, WP_TILE_BLOCK_DIM> result;
1070
+
1071
+ static_assert(result.NumRegs == Length);
1072
+
1073
+ for (int i=0; i < Length; ++i)
1074
+ result.data[i] = x[i];
1075
+
1076
+ return result;
1077
+ }
1078
+
1079
+ // construct a tile from a local SIMT value (one per-thread)
1080
+ template <typename T, typename AdjTile>
1081
+ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
1082
+ {
1083
+ static_assert(AdjTile::M == 1);
1084
+ static_assert(AdjTile::N == WP_TILE_BLOCK_DIM);
1085
+
1086
+ auto adj_reg = adj_ret.copy_to_register();
1087
+
1088
+ adj_x += adj_reg.data[0];
1089
+ }
1090
+
1091
+ template <typename T, unsigned Length, typename AdjTile>
1092
+ inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
1093
+ {
1094
+ static_assert(AdjTile::M == Length);
1095
+ static_assert(AdjTile::N == WP_TILE_BLOCK_DIM);
1096
+
1097
+ auto adj_reg = adj_ret.copy_to_register();
1098
+
1099
+ for (int i=0; i < Length; ++i)
1100
+ adj_x[i] += adj_reg.data[i];
1101
+ }
1102
+
1103
+ template <typename Tile>
1104
+ inline CUDA_CALLABLE auto untile(Tile& tile)
1105
+ {
1106
+ // code-gen should have set the tile to
1107
+ // have exactly the block dimension so
1108
+ // there is exactly one value per-thread
1109
+ auto reg = tile.copy_to_register();
1110
+
1111
+ // scalar case
1112
+ if constexpr(Tile::M == 1)
1113
+ {
1114
+ return reg.data[0];
1115
+ }
1116
+
1117
+ // vector case
1118
+ if constexpr(Tile::M > 1)
1119
+ {
1120
+ wp::vec_t<Tile::M, typename Tile::Type> v;
1121
+ for (int i=0; i < Tile::M; ++i)
1122
+ v[i] = reg.data[i];
1123
+
1124
+ return v;
1125
+ }
1126
+ }
1127
+
1128
+ template <typename Tile, typename Value>
1129
+ inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
1130
+ {
1131
+ auto adj = adj_tile.copy_to_register();
1132
+
1133
+ // scalar case
1134
+ if constexpr(Tile::M == 1)
1135
+ {
1136
+ adj.data[0] += adj_ret;
1137
+ }
1138
+
1139
+ // vector case
1140
+ if constexpr(Tile::M > 1)
1141
+ {
1142
+ for (int i=0; i < Tile::M; ++i)
1143
+ adj.data[i] = adj_ret[i];
1144
+ }
1145
+
1146
+ adj_tile.assign(adj);
1147
+ }
1148
+
1149
+ // zero initialized tile
1150
+ template <typename T, int M, int N>
1151
+ inline CUDA_CALLABLE auto tile_zeros()
1152
+ {
1153
+ // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
1154
+ return T(0);
1155
+ }
1156
+
1157
+ // one-initialized tile
1158
+ template <typename T, int M, int N>
1159
+ inline CUDA_CALLABLE auto tile_ones()
1160
+ {
1161
+ // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
1162
+ return T(1);
1163
+ }
1164
+
1165
+ // tile with evenly spaced values
1166
+ template <typename T, int M, int N>
1167
+ inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
1168
+ {
1169
+ tile_register_t<T, M, N> out;
1170
+
1171
+ WP_PRAGMA_UNROLL
1172
+ for (int i=0; i < out.NumRegs; ++i)
1173
+ {
1174
+ const int linear = out.index(i);
1175
+
1176
+ // handle case where tile size is not
1177
+ // aligned to block dimensions
1178
+ if (!out.Aligned && linear >= out.Size)
1179
+ break;
1180
+
1181
+ out.data[i] = start + linear*step;
1182
+ }
1183
+
1184
+ return out;
1185
+ }
1186
+
1187
+ template <typename T, typename AdjTile>
1188
+ inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
1189
+ T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
1190
+
1191
+ // entry point for 1d load
1192
+ template <typename T, int N>
1193
+ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, int x)
1194
+ {
1195
+ return tile_global_t<T>(src, x, 0);
1196
+ }
1197
+
1198
+ // entry point for 2d load
1199
+ template <typename T, int M, int N>
1200
+ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, int x, int y)
1201
+ {
1202
+ return tile_global_t<T>(src, x, y);
1203
+ }
1204
+
1205
+ // entry point for 1d store
1206
+ template <typename T, typename Tile>
1207
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src)
1208
+ {
1209
+ // dispatch to tile type
1210
+ src.copy_to_global(dest, x);
1211
+ }
1212
+
1213
+ // entry point for 2d store
1214
+ template <typename T, typename Tile>
1215
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src)
1216
+ {
1217
+ // dispatch to tile type
1218
+ src.copy_to_global(dest, x, y);
1219
+ }
1220
+
1221
+ template <typename T, typename Tile>
1222
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src)
1223
+ {
1224
+ auto src_reg = src.copy_to_register();
1225
+
1226
+ const int tile_i = x*src_reg.M;
1227
+ const int tile_j = y*src_reg.N;
1228
+
1229
+ tile_register_t<T, src_reg.M, src_reg.N> previous;
1230
+
1231
+ WP_PRAGMA_UNROLL
1232
+ for (int i=0; i < src_reg.NumRegs; ++i)
1233
+ {
1234
+ // handle case where tile size is not
1235
+ // aligned to block dimensions
1236
+ int linear = src_reg.index(i);
1237
+ if (!src_reg.Aligned && linear >= src_reg.Size)
1238
+ break;
1239
+
1240
+ coord_t c = src_reg.coord(linear);
1241
+ previous.data[i] = atomic_add(dest, tile_i + c.i, tile_j + c.j, src_reg.data[i]);
1242
+ }
1243
+
1244
+ return previous;
1245
+ }
1246
+
1247
+
1248
+
1249
+ //-------------------------------------
1250
+ // Adjoints
1251
+
1252
+ template <typename T, typename AdjTile>
1253
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x,
1254
+ array_t<T>& adj_src, int adj_x,
1255
+ AdjTile& adj_ret)
1256
+ {
1257
+ // early out
1258
+ // if (!src.grad)
1259
+ // return;
1260
+
1261
+ auto adj_reg = adj_ret.grad_to_register();
1262
+
1263
+ const int tile_i = x*adj_reg.N;
1264
+
1265
+ // add gradients to src array
1266
+ WP_PRAGMA_UNROLL
1267
+ for (int i=0; i < adj_reg.NumRegs; ++i)
1268
+ {
1269
+ int linear = adj_reg.index(i);
1270
+ if (!adj_reg.Aligned && linear >= adj_reg.Size)
1271
+ break;
1272
+
1273
+ auto grad = adj_reg.data[i];
1274
+
1275
+ if (adj_src.data)
1276
+ adj_atomic_add(&index(adj_src, tile_i + linear), grad);
1277
+ else if (src.grad)
1278
+ adj_atomic_add(&index_grad(src, tile_i + linear), grad);
1279
+ }
1280
+ }
1281
+
1282
+ template <typename T, typename AdjTile>
1283
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y,
1284
+ array_t<T>& adj_src, int adj_x, int adj_y,
1285
+ AdjTile& adj_ret)
1286
+ {
1287
+ // early out
1288
+ // if (!src.grad)
1289
+ // return;
1290
+
1291
+ auto adj_reg = adj_ret.grad_to_register();
1292
+
1293
+ const int tile_i = x*adj_reg.M;
1294
+ const int tile_j = y*adj_reg.N;
1295
+
1296
+ // add gradients to src array
1297
+ WP_PRAGMA_UNROLL
1298
+ for (int i=0; i < adj_reg.NumRegs; ++i)
1299
+ {
1300
+ int linear = adj_reg.index(i);
1301
+ if (!adj_reg.Aligned && linear >= adj_reg.Size)
1302
+ break;
1303
+
1304
+ coord_t coord = adj_reg.coord(linear);
1305
+
1306
+ auto grad = adj_reg.data[i];
1307
+
1308
+ if (adj_src.data)
1309
+ adj_atomic_add(&index(adj_src, tile_i + coord.i, tile_j + coord.j), grad);
1310
+ else if (src.grad)
1311
+ adj_atomic_add(&index_grad(src, tile_i + coord.i, tile_j + coord.j), grad);
1312
+ }
1313
+ }
1314
+
1315
+
1316
+ template <typename T, typename Tile, typename AdjTile>
1317
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t)
1318
+ {
1319
+ // convert to register if necessary
1320
+ tile_register_t<T, AdjTile::M, AdjTile::N> adj_reg;
1321
+
1322
+ const int tile_i = x*adj_reg.N;
1323
+
1324
+ // load gradients from output
1325
+ WP_PRAGMA_UNROLL
1326
+ for (int i=0; i < adj_reg.NumRegs; ++i)
1327
+ {
1328
+ int linear = adj_reg.index(i);
1329
+ if (!adj_reg.Aligned && linear >= adj_reg.Size)
1330
+ break;
1331
+
1332
+ if (adj_dest.data)
1333
+ adj_reg.data[i] = index(adj_dest, tile_i + linear);
1334
+ else if (dest.grad)
1335
+ adj_reg.data[i] = index_grad(dest, tile_i + linear);
1336
+ }
1337
+
1338
+ // store adjoint back to tile
1339
+ adj_t.grad_add(adj_reg);
1340
+ }
1341
+
1342
+ template <typename T, typename Tile, typename AdjTile>
1343
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t)
1344
+ {
1345
+ // allocate register tile to load grads into
1346
+ tile_register_t<T, AdjTile::M, AdjTile::N> adj_reg;
1347
+
1348
+ const int tile_i = x*adj_reg.M;
1349
+ const int tile_j = y*adj_reg.N;
1350
+
1351
+ // load gradients from output
1352
+ WP_PRAGMA_UNROLL
1353
+ for (int i=0; i < adj_reg.NumRegs; ++i)
1354
+ {
1355
+ int linear = adj_reg.index(i);
1356
+ if (!adj_reg.Aligned && linear >= adj_reg.Size)
1357
+ break;
1358
+
1359
+ coord_t coord = adj_reg.coord(linear);
1360
+
1361
+ if (adj_dest.data)
1362
+ adj_reg.data[i] = index(adj_dest, tile_i + coord.i, tile_j + coord.j);
1363
+ else if (dest.grad)
1364
+ adj_reg.data[i] = index_grad(dest, tile_i + coord.i, tile_j + coord.j);
1365
+ }
1366
+
1367
+ // store adjoint back to tile
1368
+ adj_t.grad_add(adj_reg);
1369
+ }
1370
+
1371
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
1372
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret)
1373
+ {
1374
+ adj_tile_store(dest, x, y, t, adj_dest, adj_x, adj_y, adj_t);
1375
+ }
1376
+
1377
+
1378
+ // unary map
1379
+ template <typename Tile, typename Fwd>
1380
+ inline CUDA_CALLABLE auto tile_map(Fwd op,
1381
+ Tile &a)
1382
+ {
1383
+ auto out = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();
1384
+ auto a_reg = a.copy_to_register();
1385
+
1386
+ WP_PRAGMA_UNROLL
1387
+ for (int i=0; i < out.NumRegs; ++i)
1388
+ {
1389
+ out.data[i] = op(a_reg.data[i]);
1390
+ }
1391
+
1392
+ return out;
1393
+ }
1394
+
1395
+
1396
+ template <typename Tile, typename AdjTile, typename Fwd, typename Adj>
1397
+ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
1398
+ Tile& a,
1399
+ Adj adj_op,
1400
+ Tile& adj_a,
1401
+ AdjTile& adj_ret)
1402
+ {
1403
+ auto a_reg = a.copy_to_register();
1404
+ auto adj_a_reg = tile_register_like<Tile>();
1405
+ auto adj_ret_reg = adj_ret.grad_to_register();
1406
+
1407
+ WP_PRAGMA_UNROLL
1408
+ for (int i=0; i < a_reg.NumRegs; ++i)
1409
+ {
1410
+ adj_op(a_reg.data[i], adj_a_reg.data[i], adj_ret_reg.data[i]);
1411
+ }
1412
+
1413
+ // write adjoints back
1414
+ adj_a.grad_add(adj_a_reg);
1415
+ }
1416
+
1417
+ // binary map
1418
+ template <typename TileA, typename TileB, typename Fwd>
1419
+ inline CUDA_CALLABLE auto tile_map(Fwd op,
1420
+ TileA& a,
1421
+ TileB& b)
1422
+ {
1423
+ auto out = tile_register_t<typename TileA::Type, TileA::M, TileA::N>();
1424
+
1425
+ auto a_reg = a.copy_to_register();
1426
+ auto b_reg = b.copy_to_register();
1427
+
1428
+ WP_PRAGMA_UNROLL
1429
+ for (int i=0; i < out.NumRegs; ++i)
1430
+ out.data[i] = op(a_reg.data[i], b_reg.data[i]);
1431
+
1432
+ return out;
1433
+ }
1434
+
1435
+
1436
+ template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
1437
+ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
1438
+ TileA &a,
1439
+ TileB &b,
1440
+ Adj adj_op,
1441
+ TileA &adj_a,
1442
+ TileB &adj_b,
1443
+ AdjTile &adj_ret)
1444
+ {
1445
+ auto a_reg = a.copy_to_register();
1446
+ auto b_reg = b.copy_to_register();
1447
+
1448
+ // allocate storage for adjoints
1449
+ auto adj_a_reg = tile_register_like<TileA>();
1450
+ auto adj_b_reg = tile_register_like<TileB>();
1451
+
1452
+ auto adj_ret_reg = adj_ret.grad_to_register();
1453
+
1454
+ WP_PRAGMA_UNROLL
1455
+ for (int i=0; i < a_reg.NumRegs; ++i)
1456
+ {
1457
+ adj_op(a_reg.data[i], b_reg.data[i], adj_a_reg.data[i], adj_b_reg.data[i], adj_ret_reg.data[i]);
1458
+ }
1459
+
1460
+ adj_a.grad_add(adj_a_reg);
1461
+ adj_b.grad_add(adj_b_reg);
1462
+ }
1463
+
1464
+ // wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
1465
+ // this is important because many of the builtin operators don't follow particular conventions on references for
1466
+ // the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting
1467
+ #define tile_unary_map(op, a) tile_map([](auto x) { return op(x);}, a)
1468
+ #define adj_tile_unary_map(op, a, adj_op, adj_a, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
1469
+
1470
+ #define tile_binary_map(op, a, b) tile_map([](auto x, auto y) { return op(x, y);}, a, b)
1471
+ #define adj_tile_binary_map(op, a, b, adj_op, adj_a, adj_b, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
1472
+
1473
+ // -tile (unary neg)
1474
+ template <typename Tile>
1475
+ inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a); }
1476
+
1477
+ template <typename Tile, typename AdjTile>
1478
+ inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, wp::adj_neg, adj_a, adj_ret); }
1479
+
1480
+
1481
+ // tile + tile
1482
+ template <typename TileA, typename TileB>
1483
+ inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
1484
+ {
1485
+ return tile_binary_map(add, a, b);
1486
+ }
1487
+
1488
+ // // tile + tile, we implement this
1489
+ // template <typename TileA, typename TileB>
1490
+ // inline CUDA_CALLABLE auto add(TileA& a, TileB& b)
1491
+ // {
1492
+ // return tile_binary_map(add, a, b);
1493
+ // }
1494
+
1495
+
1496
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
1497
+ inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
1498
+ {
1499
+ adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
1500
+ }
1501
+
1502
+ // tile*scalar
1503
+ template <typename Tile>
1504
+ inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
1505
+ {
1506
+ // promote scalar to a constant tile
1507
+ auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
1508
+
1509
+ return tile_binary_map(mul, a, s_tile);
1510
+ }
1511
+
1512
+ template <typename Tile, typename AdjTile>
1513
+ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
1514
+ Tile& adj_a, typename Tile::Type& adj_s,
1515
+ AdjTile& adj_c)
1516
+ {
1517
+ auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
1518
+ auto adj_s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();
1519
+
1520
+ adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
1521
+
1522
+ for (int i=0; i < adj_s_tile.NumRegs; ++i)
1523
+ {
1524
+ adj_s += adj_s_tile.data[i];
1525
+ }
1526
+ }
1527
+
1528
+
1529
+ // scalar*tile
1530
+ template <typename Tile>
1531
+ inline CUDA_CALLABLE auto tile_mul(const typename Tile::Type& s, Tile& a)
1532
+ {
1533
+ // promote scalar to a constant tile
1534
+ auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
1535
+
1536
+ return tile_binary_map(mul, s_tile, a);
1537
+ }
1538
+
1539
+ template <typename Tile, typename AdjTile>
1540
+ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
1541
+ typename Tile::Type& adj_s, Tile& adj_a,
1542
+ AdjTile& adj_c)
1543
+ {
1544
+ auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
1545
+ auto adj_s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();
1546
+
1547
+ adj_tile_binary_map(mul, s_tile, a, adj_mul, adj_s_tile, adj_a, adj_c);
1548
+
1549
+ for (int i=0; i < adj_s_tile.NumRegs; ++i)
1550
+ {
1551
+ adj_s += adj_s_tile.data[i];
1552
+ }
1553
+ }
1554
+
1555
+
1556
+
1557
+ template<typename Tile>
1558
+ typename Tile::Type tile_extract(Tile& t, int i, int j)
1559
+ {
1560
+ assert(i < Tile::M);
1561
+ assert(j < Tile::N);
1562
+
1563
+ return t.extract(i, j);
1564
+ }
1565
+
1566
+ template<typename Tile, typename AdjTile>
1567
+ void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret)
1568
+ {
1569
+ assert(i < Tile::M);
1570
+ assert(j < Tile::N);
1571
+
1572
+ adj_t.adj_extract(i, j, adj_ret);
1573
+ }
1574
+
1575
+ namespace partitioned_gemm
+ {
+
+ template <typename T>
+ inline CUDA_CALLABLE const T& index(const T* __restrict__ p, int i, int j, int stride)
+ {
+     return p[i*stride + j];
+ }
+
+ template <typename T>
+ inline CUDA_CALLABLE T& index(T* __restrict__ p, int i, int j, int stride)
+ {
+     return p[i*stride + j];
+ }
+
+ template <int PartitionM, int PartitionN, typename Tile>
+ struct partition_t
+ {
+     static constexpr int M = PartitionM;
+     static constexpr int N = PartitionN;
+     static constexpr int Stride = Tile::N;
+
+     using T = typename Tile::Type;
+
+     inline partition_t(Tile& A)
+     {
+         data = A.data.ptr;
+
+         // todo: do ceil div for non-multiples of M,N
+         shape[0] = Tile::M/PartitionM;
+         shape[1] = Tile::N/PartitionN;
+     }
+
+     // underlying data
+     T* data;
+
+     // partition dimensions
+     int shape[2];
+ };
+
+ template <typename Partition>
+ inline int partition_size(const Partition& part)
+ {
+     return part.shape[0]*part.shape[1];
+ }
+
+ // returns the (i, j) coordinates of a sub-tile given its linear index
+ template <typename Partition>
+ inline void partition_coord(const Partition& part, const int t, int& i, int& j)
+ {
+     i = t/part.shape[1];
+     j = t%part.shape[1];
+ }
+
+ template <typename Partition>
+ inline auto partition_load(const Partition& tile, int i, int j)
+ {
+     mat_t<Partition::M, Partition::N, typename Partition::T> out;
+
+     const int tile_i = i*Partition::M;
+     const int tile_j = j*Partition::N;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Partition::M; ++i)
+     {
+         WP_PRAGMA_UNROLL
+         for (int j=0; j < Partition::N; ++j)
+         {
+             out.data[i][j] = index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
+         }
+     }
+
+     return out;
+ }
+
+ template <typename Partition, typename Value>
+ inline void partition_store(const Partition& tile, int i, int j, const Value& value)
+ {
+     const int tile_i = Partition::M*i;
+     const int tile_j = Partition::N*j;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Partition::M; ++i)
+     {
+         WP_PRAGMA_UNROLL
+         for (int j=0; j < Partition::N; ++j)
+         {
+             index(tile.data, tile_i + i, tile_j + j, Partition::Stride) = value.data[i][j];
+         }
+     }
+ }
+
+ template <typename TileA, typename TileB, typename TileC>
+ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
+ {
+     const int TILE_M = 4;
+     const int TILE_N = 4;
+     const int TILE_K = 4;
+
+     auto A_tile = partition_t<TILE_M, TILE_K, TileA>(A);
+     auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
+     auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);
+
+     const int length = partition_size(C_tile);
+
+     for (int t=threadIdx.x; t < length; t += blockDim.x)
+     {
+         int i, j;
+         partition_coord(C_tile, t, i, j);
+
+         // accumulator
+         auto sum = partition_load(C_tile, i, j);
+
+         WP_PRAGMA_UNROLL
+         for (int k=0; k < A_tile.shape[1]; k++)
+         {
+             const auto a = partition_load(A_tile, i, k);
+             const auto b = partition_load(B_tile, k, j);
+
+             sum += mul(a, b);
+         }
+
+         partition_store(C_tile, i, j, sum);
+     }
+ }
+
+ } // namespace partitioned_gemm
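The partitioned GEMM above hands each thread in the block a 4x4 sub-tile of C: the thread loads its accumulator sub-tile into registers, sweeps the matching row of A sub-tiles against the column of B sub-tiles, and stores the result back. A single-threaded host sketch of the same blocking scheme (plain C++, hypothetical sizes, not the Warp implementation):

    #include <cassert>
    #include <cstdio>

    // C (MxN) += A (MxK) * B (KxN), processed in BM x BN output blocks over a
    // row-major layout, mirroring partition_load / partition_store.
    template <int M, int N, int K, int BM, int BN, int BK>
    void blocked_matmul(const float* A, const float* B, float* C)
    {
        for (int bi = 0; bi < M / BM; ++bi)
            for (int bj = 0; bj < N / BN; ++bj)
            {
                // register-resident accumulator for this output block
                float acc[BM][BN] = {};

                for (int bk = 0; bk < K / BK; ++bk)
                    for (int i = 0; i < BM; ++i)
                        for (int j = 0; j < BN; ++j)
                            for (int k = 0; k < BK; ++k)
                                acc[i][j] += A[(bi*BM + i)*K + bk*BK + k] *
                                             B[(bk*BK + k)*N + bj*BN + j];

                for (int i = 0; i < BM; ++i)
                    for (int j = 0; j < BN; ++j)
                        C[(bi*BM + i)*N + bj*BN + j] += acc[i][j];
            }
    }

    int main()
    {
        const int M = 8, N = 8, K = 8;
        float A[M*K], B[K*N], C[M*N] = {}, ref[M*N] = {};
        for (int i = 0; i < M*K; ++i) A[i] = float(i % 5);
        for (int i = 0; i < K*N; ++i) B[i] = float(i % 3);

        blocked_matmul<M, N, K, 4, 4, 4>(A, B, C);

        for (int i = 0; i < M; ++i)          // naive reference
            for (int j = 0; j < N; ++j)
                for (int k = 0; k < K; ++k)
                    ref[i*N + j] += A[i*K + k] * B[k*N + j];

        for (int i = 0; i < M*N; ++i) assert(C[i] == ref[i]);
        printf("blocked result matches naive reference\n");
        return 0;
    }

In the device code the outer (bi, bj) loop is replaced by the block-stride loop over threads, with partition_coord recovering the sub-tile coordinates from the linear index.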
1702
+
+ template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C)
+ {
+     using T = typename TileA::Type;
+
+ #if WP_USE_ASYNC_PIPELINE
+     __pipeline_wait_prior(0);
+     WP_TILE_SYNC();
+ #endif
+
+ #if WP_USE_REGISTER_GEMM
+     partitioned_gemm::matmul(A, B, C);
+ #else
+     fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
+ #endif
+
+     WP_TILE_SYNC();
+
+     return C;
+ }
+
+ // backward for the wp.tile_matmul(a, b, out) syntax
+ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
+                      Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C)
+ {
+     using T = typename TileA::Type;
+
+     fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
+     fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
+     WP_TILE_SYNC();
+ }
+
+ // backward for the out = wp.tile_matmul(a, b) syntax
+ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
+                      Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C, TileC& adj_ret)
+ {
+     using T = typename TileA::Type;
+
+     fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
+     fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
+     WP_TILE_SYNC();
+ }
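Both adj_tile_matmul overloads accumulate the standard reverse-mode updates for C = A*B, namely adj_A += adj_C * B^T and adj_B += A^T * adj_C, via the generated backward GEMM calls. A small host-side check of that rule against a finite difference (plain C++, arbitrary illustrative sizes and values):

    #include <cstdio>

    // Compare the analytic adj_A = adj_C * B^T against a central finite
    // difference of L(A) = sum_ij adj_C[i][j] * C[i][j], where C = A*B.
    int main()
    {
        const int M = 2, K = 3, N = 2;
        double A[M][K]     = {{1, 2, 3}, {4, 5, 6}};
        double B[K][N]     = {{1, 0}, {2, 1}, {0, 3}};
        double adj_C[M][N] = {{0.5, -1.0}, {2.0, 0.25}};

        auto loss = [&]() {
            double L = 0.0;
            for (int i = 0; i < M; ++i)
                for (int j = 0; j < N; ++j)
                {
                    double c = 0.0;
                    for (int k = 0; k < K; ++k)
                        c += A[i][k] * B[k][j];
                    L += adj_C[i][j] * c;
                }
            return L;
        };

        const double eps = 1e-6;
        for (int i = 0; i < M; ++i)
            for (int k = 0; k < K; ++k)
            {
                // analytic: adj_A[i][k] = sum_j adj_C[i][j] * B[k][j]   (adj_C * B^T)
                double analytic = 0.0;
                for (int j = 0; j < N; ++j)
                    analytic += adj_C[i][j] * B[k][j];

                // central finite difference of L with respect to A[i][k]
                const double saved = A[i][k];
                A[i][k] = saved + eps; const double Lp = loss();
                A[i][k] = saved - eps; const double Lm = loss();
                A[i][k] = saved;

                printf("adj_A[%d][%d]: analytic=%.6f numeric=%.6f\n",
                       i, k, analytic, (Lp - Lm) / (2.0 * eps));
            }
        return 0;
    }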
1747
+
+ // TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
+ // TODO(lcambier): use dynamic smem
+ #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
+     do { \
+         void function_name(dtype*, dtype*); \
+         WP_TILE_SHARED __align__(16) char buffer[shared_memory_size]; \
+         __align__(16) dtype data[ept]; \
+         for(int b = 0; b < (int)batch_size; b++) { \
+             dtype* inout = Xinout.data + (int)b * (int)ept; \
+             memcpy(data, inout, sizeof(dtype) * ept); \
+             function_name(data, (dtype*)buffer); \
+             memcpy(inout, data, sizeof(dtype) * ept); \
+             WP_TILE_SYNC(); \
+         } \
+     } while (0)
+
+ #define tile_ifft tile_fft
+
+ // adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept are all ignored
+
+ #define adj_tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
+                      adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
+                      adj_Xinout) \
+     do { \
+         tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
+     } while (0)
+
+ #define adj_tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
+                       adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
+                       adj_Xinout) \
+     do { \
+         tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
+     } while (0)
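The backward macros push the incoming adjoint through the opposite transform, which works because the DFT is linear: for y = F x, reverse mode propagates adj_x = F^H adj_y, and the conjugate transpose of the DFT matrix equals its inverse up to the 1/N normalization. A small numerical check of that identity, F^H F = N*I (self-contained C++, not tied to cuFFTDx):

    #include <algorithm>
    #include <cmath>
    #include <complex>
    #include <cstdio>

    // Verify F^H * F = N * I for the length-N DFT matrix F[j][k] = exp(-2*pi*i*j*k/N).
    int main()
    {
        const int N = 4;
        const double pi = 3.14159265358979323846;
        std::complex<double> F[N][N];

        for (int j = 0; j < N; ++j)
            for (int k = 0; k < N; ++k)
                F[j][k] = std::exp(std::complex<double>(0.0, -2.0 * pi * j * k / N));

        double max_err = 0.0;
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < N; ++k)
            {
                std::complex<double> s(0.0, 0.0);
                for (int m = 0; m < N; ++m)
                    s += std::conj(F[m][j]) * F[m][k];      // (F^H F)[j][k]

                const std::complex<double> expected = (j == k) ? std::complex<double>(N, 0.0)
                                                               : std::complex<double>(0.0, 0.0);
                max_err = std::max(max_err, std::abs(s - expected));
            }

        printf("max |F^H F - N*I| = %g\n", max_err);        // ~1e-15
        return 0;
    }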
1781
+
+
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_transpose(Tile& t)
+ {
+     // alias the incoming tile with transposed shape and swapped strides
+     return tile_shared_t<typename Tile::Type, Tile::N, Tile::M, Tile::StrideN, Tile::StrideM, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+ {
+     auto a = tile_transpose(adj_ret);
+     auto b = adj_t;
+
+     adj_t.assign(tile_add(a,b));
+ }
+
+ template <int M, int N, int StrideM, int StrideN, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, M, N, StrideM, StrideN, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+ {
+     // no-op: the adjoint aliases the same memory, so gradients are already accumulated in place
+ }
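tile_transpose and tile_broadcast never copy data; they re-describe the same shared-memory buffer with different shapes and strides, which is also why the broadcast adjoint is a no-op. A host-side sketch of the same strided-view idea (plain C++ with an illustrative helper struct; not the tile_shared_t implementation):

    #include <cassert>

    // A strided 2D view over a flat buffer: view(i, j) = data[i*stride_m + j*stride_n].
    // Swapping the strides gives a transposed view; a zero stride repeats a row (broadcast).
    struct view2d
    {
        float* data;
        int m, n;            // logical shape
        int stride_m, stride_n;

        float& operator()(int i, int j) { return data[i*stride_m + j*stride_n]; }
    };

    int main()
    {
        float buf[2*3] = {0, 1, 2,
                          3, 4, 5};

        view2d a  {buf, 2, 3, 3, 1};   // 2x3 row-major view
        view2d at {buf, 3, 2, 1, 3};   // 3x2 transposed view: swapped shape and strides
        view2d row{buf, 2, 3, 0, 1};   // row 0 broadcast over 2 rows (stride_m = 0)

        assert(a(1, 2) == 5.0f);
        assert(at(2, 1) == a(1, 2));   // at(i, j) aliases a(j, i)
        assert(row(1, 2) == a(0, 2));  // broadcast rows all read row 0

        at(0, 1) = 42.0f;              // writes through the view hit the same buffer
        assert(a(1, 0) == 42.0f);
        return 0;
    }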
1812
+
+ template <int M, int N, typename Tile>
+ inline CUDA_CALLABLE auto tile_view(Tile& t, int i, int j)
+ {
+     // alias an MxN sub-tile of the incoming tile (same strides, offset origin)
+     return tile_shared_t<typename Tile::Type, M, N, Tile::StrideM, Tile::StrideN, false>(&t.data(i, j), &t.grad(i, j));
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_view(Tile& t, int i, int j, Tile& adj_t, int adj_i, int adj_j, AdjTile& adj_ret)
+ {
+     // no-op: the view aliases the same memory, so gradients are already accumulated in place
+ }
+
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, int i, int j, TileB& src)
+ {
+     for (int t=threadIdx.x; t < src.Size; t += WP_TILE_BLOCK_DIM)
+     {
+         coord_t c = src.coord(t);
+         dest.data(i + c.i, j + c.j) = src.data(c.i, c.j);
+     }
+
+     WP_TILE_SYNC();
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, int i, int j, TileB& src,
+                                           AdjTileA& adj_dest, int adj_i, int adj_j, AdjTileB& adj_src)
+ {
+     for (int t=threadIdx.x; t < src.Size; t += WP_TILE_BLOCK_DIM)
+     {
+         coord_t c = src.coord(t);
+         src.grad(c.i, c.j) += dest.grad(i + c.i, j + c.j);
+     }
+
+     WP_TILE_SYNC();
+ }
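tile_assign and its adjoint copy the sub-tile cooperatively: thread t handles elements t, t + WP_TILE_BLOCK_DIM, t + 2*WP_TILE_BLOCK_DIM, and so on, so every element is touched exactly once before the block-wide sync. A serial emulation of that block-stride partitioning (plain C++, arbitrary illustrative sizes):

    #include <cassert>
    #include <cstdio>

    // Emulate: for (int t = threadIdx.x; t < size; t += blockDim.x) { ... }
    int main()
    {
        const int size = 10;       // number of tile elements to copy
        const int block_dim = 4;   // stand-in for WP_TILE_BLOCK_DIM
        int writes[size] = {};

        for (int tid = 0; tid < block_dim; ++tid)          // each "thread"
            for (int t = tid; t < size; t += block_dim)    // its strided elements
                writes[t] += 1;

        for (int t = 0; t < size; ++t)
            assert(writes[t] == 1);                        // written exactly once

        printf("each element handled by exactly one thread\n");
        return 0;
    }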
1851
+
+
+
+ } // namespace wp