blksprs-1.11-py3-none-any.whl → blksprs-2.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/softmax.py CHANGED
@@ -1,17 +1,18 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl

 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
+from blksprs.utils.tools import stride, get_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size


-def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-            triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
+def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Computes the softmax of a block-sparse tensor in compressed form.

     Note:
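
Note on the two new imports: they target private torch._library paths. As a compatibility sketch (the public torch.library.triton_op and torch.library.wrap_triton names are an assumption about newer PyTorch releases, not something this diff establishes), a consumer could fall back as follows:

    try:
        from torch.library import triton_op, wrap_triton  # assumed public names in newer torch
    except ImportError:  # fall back to the private paths used in this release
        from torch._library import triton_op
        from torch._library.triton import wrap_triton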
@@ -21,7 +22,6 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
         x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
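
Since 2.0rc1 drops the triton_block_size parameter, here is a minimal usage sketch of the new signature; the device, dtypes, block size, and fully dense layout are illustrative assumptions, not something this diff pins down:

    import torch
    from blksprs.ops.softmax import softmax
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    sparsity_block_size = 16
    # Assumed layout: 2 batches of 4x4 sparsity blocks, all present.
    sparsity_layout = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")
    # Compressed form: one (16, 16) tile per nonzero layout entry.
    x = BlksprsTensor(torch.randn(int(sparsity_layout.sum()),
                                  sparsity_block_size, sparsity_block_size,
                                  device="cuda", requires_grad=True))

    # triton_block_size is gone; the kernels are autotuned instead.
    y = softmax(x, sparsity_layout, sparsity_block_size)

    # Reusing the lut dict across calls with the same layout skips rebuilding
    # the look-up tables, since softmax_build_lut only fills in missing keys.
    lut = {}
    y1 = softmax(x, sparsity_layout, sparsity_block_size, lut=lut)
    y2 = softmax(x, sparsity_layout, sparsity_block_size, lut=lut)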
@@ -35,244 +35,272 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    lut = _BlocksparseSoftmax.build_lut(lut, sparsity_layout)
-
-    return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
-                                                   lut["sparsity_lut"],
-                                                   lut["sparsity_reverse_lut_rws"],
-                                                   sparsity_block_size, triton_block_size))
-
-
-class _BlocksparseSoftmax(torch.autograd.Function):
-
-    @staticmethod
-    def build_lut(lut: dict, sparsity_layout: Tensor):
-        if lut is None:
-            lut = dict()
-
-        if "sparsity_lut" not in lut:
-            sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-            lut["sparsity_lut"] = sparsity_lut
-
-
-        if "sparsity_reverse_lut_rws" not in lut:
-            sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-            sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
-            sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
-                                        (sparsity_layout_rws_flat == 1) -
-                                        (1 * (sparsity_layout_rws_flat == 0)))
-            lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
-
-        validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
-
-        return lut
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout: Tensor,
-                sparsity_lut: Tensor,
-                sparsity_reverse_lut_rws: Tensor,
-                sparsity_block_size: int, triton_block_size: int) -> Tensor:
-        output = torch.empty_like(x)
-
-        x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = stride(x)
-        s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-        o_b, o_r, o_c = output.size()
-
-        x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
-                                                           flag_slice_only=True,
-                                                           triton_block_size=triton_block_size)
-        x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
-        x_exp = torch.exp(x_scaled)
-        x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
-                                                               flag_slice_only=True,
-                                                               triton_block_size=triton_block_size)
-
-        s_b, s_r, s_c = x_exp_row_wise_sum.shape
-        s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
-        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
-        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        triton_grid = lambda meta: [o_b,
-                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]
-         (x_exp,
-          x_b, x_b_s, x_r_s, x_c_s,
-          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-          x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
-          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-          sparsity_reverse_lut_rws,
-          output,
-          triton_block_size))
-
-        # Save for backward pass
-        ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        o, sparsity_layout, sparsity_lut = ctx.saved_tensors
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
-                                            triton_block_size=triton_block_size)
-
-        sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
-        sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
-                                  (sparsity_layout_s_flat == 1) -
-                                  (1 * (sparsity_layout_s_flat == 0)))
-
-        o_b, o_r, o_c = o.size()
-        o_b_s, o_r_s, o_c_s = stride(o)
-        s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-        s_b, s_r, s_c = s.size()
-        s_b_s, s_r_s, s_c_s = stride(s)
-        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
-        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
-
-        grad_x = torch.empty_like(o, dtype=torch.float)
-
-        triton_grid = lambda meta: [o_b,
-                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (_BlocksparseSoftmax.kernel_blocksparse_softmax_grad_x[triton_grid]
-         (grad_output,
-          o_b, o_b_s, o_r_s, o_c_s,
-          o,
-          o_b, o_b_s, o_r_s, o_c_s,
-          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-          s,
-          s_b, s_b_s, s_r_s, s_c_s,
-          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-          sparsity_reverse_lut_s,
-          grad_x,
-          o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size
-          ))
-
-        return grad_x, None, None, None, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_softmax(x,
-                                   x_b, x_b_s, x_r_s, x_c_s,
-                                   s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                   s, s_b, s_b_s, s_r_s, s_c_s,
-                                   s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                   r_lut_s,
-                                   o,
-                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch and row index
-        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-        # Get reverse sparsity indices for s
-        rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                             spa_row * s_l_s_r_s)
-        rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-        if rev_idx_spa_s >= 0:
-            # Load x block
-            blk_x_idx = ((pid_blk * x_b_s) +
-                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-            blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-            blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-            # Load sum block
-            blk_s_idx = (rev_idx_spa_s * s_b_s +
-                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
-                         (tl.arange(0, 1) * s_c_s)[None, :])
-            blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
-            blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
-
-            # Compute softmax
-            buf = tl.div_rn(blk_x, blk_s)
-
-            # Store output
-            tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_softmax_grad_x(g,
-                                          g_b, g_b_s, g_r_s, g_c_s,
-                                          x,
-                                          x_b, x_b_s, x_r_s, x_c_s,
-                                          s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                          s,
-                                          s_b, s_b_s, s_r_s, s_c_s,
-                                          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                          r_lut_s,
-                                          o,
-                                          o_b, o_b_s, o_r_s, o_c_s,
-                                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch and row index
-        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-        rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                             spa_row * s_l_s_r_s)
-        rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-        if rev_idx_spa_s >= 0:
-            blk_s_idx = (rev_idx_spa_s * s_b_s +
-                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
-                         (tl.arange(0, 1) * s_c_s)[None, :])
-            blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
-            blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
-
-            blk_g_idx = ((pid_blk * g_b_s) +
-                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
-                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-            blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
-            blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
-
-            blk_x_idx = ((pid_blk * x_b_s) +
-                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-            blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-            blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-            buf = blk_x * (blk_g - blk_s)
-
-            blk_o_idx = ((pid_blk * o_b_s) +
-                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-            blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-            tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+    lut = softmax_build_lut(lut, sparsity_layout)
+
+    return BlksprsTensor(softmax_forward(x, sparsity_layout,
+                                         lut["sparsity_lut"],
+                                         lut["sparsity_reverse_lut_rws"],
+                                         sparsity_block_size))
+
+
+@triton_op("blksprs::softmax", mutates_args={})
+def softmax_forward(x: Tensor, sparsity_layout: Tensor,
+                    sparsity_lut: Tensor,
+                    sparsity_reverse_lut_rws: Tensor,
+                    sparsity_block_size: int) -> Tensor:
+    output = torch.empty_like(x)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b, o_r, o_c = output.size()
+
+    x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
+                                                       flag_slice_only=True)
+    x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
+    x_exp = torch.exp(x_scaled)
+    x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
+                                                           flag_slice_only=True)
+
+    s_b, s_r, s_c = x_exp_row_wise_sum.shape
+    s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
+    s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
+    s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (wrap_triton(softmax_kernel)[triton_grid]
+     (x_exp,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
+      s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+      sparsity_reverse_lut_rws,
+      output,
+      sparsity_block_size))
+
+    return output
+
+
+def softmax_backward(ctx, grad_output):
+    o, sparsity_layout, sparsity_lut = ctx.saved_tensors
+    sparsity_block_size = ctx.sparsity_block_size
+
+    s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
+
+    sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
+    sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
+                              (sparsity_layout_s_flat == 1) -
+                              (1 * (sparsity_layout_s_flat == 0)))
+
+    o_b, o_r, o_c = o.size()
+    o_b_s, o_r_s, o_c_s = stride(o)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    s_b, s_r, s_c = s.size()
+    s_b_s, s_r_s, s_c_s = stride(s)
+    s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
+    s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
+
+    grad_x = torch.empty_like(o, dtype=torch.float)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (wrap_triton(softmax_kernel_grad)[triton_grid]
+     (grad_output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      o,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      s,
+      s_b, s_b_s, s_r_s, s_c_s,
+      s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+      sparsity_reverse_lut_s,
+      grad_x,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_block_size))
+
+    return grad_x, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[]
+)
+@triton.jit
+def softmax_kernel(x,
+                   x_b, x_b_s, x_r_s, x_c_s,
+                   s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                   s, s_b, s_b_s, s_r_s, s_c_s,
+                   s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                   r_lut_s,
+                   o,
+                   sparsity_block_size,
+                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    # Get reverse sparsity indices for s
+    rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                         spa_row * s_l_s_r_s)
+    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+    rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+    if rev_idx_spa_s >= 0:
+        # Load x block
+        blk_x_idx = ((pid_blk * x_b_s) +
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = ((blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Load sum block
+        blk_s_idx = (rev_idx_spa_s * s_b_s +
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                     (tl.arange(0, 1) * s_c_s)[None, :])
+        blk_s_msk = ((blk_s_idx >= 0 and
+                      blk_s_idx < s_b * s_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, 1)[None, :] < val_tbs))
+        blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
+
+        # Compute softmax
+        buf = tl.div_rn(blk_x, blk_s)
+
+        # Store output
+        tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[]
+)
+@triton.jit
+def softmax_kernel_grad(g,
+                        g_b, g_b_s, g_r_s, g_c_s,
+                        x,
+                        x_b, x_b_s, x_r_s, x_c_s,
+                        s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                        s,
+                        s_b, s_b_s, s_r_s, s_c_s,
+                        s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                        r_lut_s,
+                        o,
+                        o_b, o_b_s, o_r_s, o_c_s,
+                        sparsity_block_size,
+                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                         spa_row * s_l_s_r_s)
+    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+    rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+    if rev_idx_spa_s >= 0:
+        blk_s_idx = (rev_idx_spa_s * s_b_s +
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                     (tl.arange(0, 1) * s_c_s)[None, :])
+        blk_s_msk = ((blk_s_idx >= 0 and
+                      blk_s_idx < s_b * s_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, 1)[None, :] < val_tbs))
+        blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
+
+        blk_g_idx = ((pid_blk * g_b_s) +
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
+        blk_g_msk = ((blk_g_idx >= 0 and
+                      blk_g_idx < g_b * g_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+        blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
+
+        blk_x_idx = ((pid_blk * x_b_s) +
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = ((blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        buf = blk_x * (blk_g - blk_s)
+
+        blk_o_idx = ((pid_blk * o_b_s) +
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = ((blk_o_idx >= 0 and
+                      blk_o_idx < o_b * o_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+        tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+
+def softmax_build_lut(lut: dict, sparsity_layout: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut_rws" not in lut:
+        sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+        sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
+        sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
+                                    (sparsity_layout_rws_flat == 1) -
+                                    (1 * (sparsity_layout_rws_flat == 0)))
+        lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
+
+    validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def softmax_setup_context(ctx, inputs, output):
+    (_, sparsity_layout, sparsity_lut, _, sparsity_block_size) = inputs
+
+    ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+softmax_forward.register_autograd(softmax_backward, setup_context=softmax_setup_context)
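
Taken together, the rewrite swaps the torch.autograd.Function subclass for the torch.library custom-op pattern: the forward is declared with @triton_op (so torch.compile can trace through it rather than graph-breaking on an opaque autograd function, which is presumably the motivation), raw kernel launches go through wrap_triton, the backward is attached with register_autograd, and the caller-supplied triton_block_size is replaced by @triton.autotune over get_autotune_configs(), with val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE) masking lanes that fall outside a sparsity block. A self-contained sketch of the same wiring on a toy elementwise op (every name below is illustrative, not part of blksprs):

    import torch
    import triton
    from torch._library import triton_op
    from torch._library.triton import wrap_triton
    from triton import language as tl


    @triton.jit
    def _scale_kernel(x_ptr, out_ptr, scale, n, BLOCK: tl.constexpr):
        # Elementwise y = x * scale over a 1D range.
        offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
        msk = offs < n
        tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=msk) * scale, mask=msk)


    @triton_op("demo::scale", mutates_args={})
    def scale(x: torch.Tensor, s: float) -> torch.Tensor:
        out = torch.empty_like(x)
        grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
        # wrap_triton makes the raw kernel launch visible to torch.compile.
        wrap_triton(_scale_kernel)[grid](x, out, s, x.numel(), BLOCK=128)
        return out


    def _scale_backward(ctx, grad):
        # d(x * s)/dx = s; the non-tensor input s gets no gradient.
        return grad * ctx.s, None


    def _scale_setup_context(ctx, inputs, output):
        _, s = inputs
        ctx.s = s


    scale.register_autograd(_scale_backward, setup_context=_scale_setup_context)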