blksprs-1.11-py3-none-any.whl → blksprs-2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/softmax.py CHANGED
@@ -1,17 +1,20 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl

  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32


- def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-             triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
      """Computes the softmax of a block-sparse tensor in compressed form.

      Note:
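
The hunk above captures both user-facing changes in 2.0: the triton_block_size parameter is gone (kernel block sizes are now picked by the autotuner configured further down), and the new @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) decorator casts inputs to float32 under CUDA autocast, matching the validate_dtype_float_32 check added below. A minimal call-site sketch, assuming the compressed form holds one sparsity_block_size-sized tile per non-zero layout entry; the shapes and layout here are illustrative, not taken from the blksprs documentation:

import torch
from blksprs.ops.softmax import softmax
from blksprs.utils.blksprs_tensor import BlksprsTensor

sparsity_block_size = 16
# One batch with a 2x2 grid of blocks, two of which are present
sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]]], device="cuda")
num_blocks = int(sparsity_layout.sum())  # 2 compressed blocks
x = BlksprsTensor(torch.randn(num_blocks, sparsity_block_size, sparsity_block_size,
                              device="cuda", dtype=torch.float32))

out = softmax(x, sparsity_layout, sparsity_block_size)  # no triton_block_size argument in 2.0
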
@@ -21,7 +24,6 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
          x (BlksprsTensor): A block-sparse tensor in compressed form.
          sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
          lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
@@ -32,102 +34,73 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,

      validate_dimensions(x)
      validate_contiguous(x)
+     validate_dtype_float_32(x)
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

-     lut = _BlocksparseSoftmax.build_lut(lut, sparsity_layout)
-
-     return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
-                                                    lut["sparsity_lut"],
-                                                    lut["sparsity_reverse_lut_rws"],
-                                                    sparsity_block_size, triton_block_size))
-
-
- class _BlocksparseSoftmax(torch.autograd.Function):
-
-     @staticmethod
-     def build_lut(lut: dict, sparsity_layout: Tensor):
-         if lut is None:
-             lut = dict()
-
-         if "sparsity_lut" not in lut:
-             sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-             lut["sparsity_lut"] = sparsity_lut
-
-
-         if "sparsity_reverse_lut_rws" not in lut:
-             sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-             sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
-             sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
-                                         (sparsity_layout_rws_flat == 1) -
-                                         (1 * (sparsity_layout_rws_flat == 0)))
-             lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
-
-         validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
-
-         return lut
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout: Tensor,
-                 sparsity_lut: Tensor,
-                 sparsity_reverse_lut_rws: Tensor,
-                 sparsity_block_size: int, triton_block_size: int) -> Tensor:
-         output = torch.empty_like(x)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_lut_r, s_lut_c = sparsity_lut.size()
-         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-         o_b, o_r, o_c = output.size()
-
-         x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
-                                                            flag_slice_only=True,
-                                                            triton_block_size=triton_block_size)
-         x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
-         x_exp = torch.exp(x_scaled)
-         x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
-                                                                flag_slice_only=True,
-                                                                triton_block_size=triton_block_size)
-
-         s_b, s_r, s_c = x_exp_row_wise_sum.shape
-         s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
-         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
-         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]
-          (x_exp,
-           x_b, x_b_s, x_r_s, x_c_s,
-           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-           x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
-           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-           sparsity_reverse_lut_rws,
-           output,
-           triton_block_size))
-
-         # Save for backward pass
-         ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         o, sparsity_layout, sparsity_lut = ctx.saved_tensors
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
-                                             triton_block_size=triton_block_size)
+     lut = softmax_build_lut(lut, sparsity_layout)
+
+     return BlksprsTensor(softmax_forward(x, sparsity_layout,
+                                          lut["sparsity_lut"],
+                                          lut["sparsity_reverse_lut_rws"],
+                                          sparsity_block_size))
+
+
+ @triton_op("blksprs::softmax_forward", mutates_args={})
+ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
+                     sparsity_lut: Tensor,
+                     sparsity_reverse_lut_rws: Tensor,
+                     sparsity_block_size: int) -> Tensor:
+     output = torch.zeros_like(x)
+
+     x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
+                                                        flag_slice_only=True)
+     x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
+     x_exp = torch.exp(x_scaled)
+     x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
+                                                            flag_slice_only=True)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     o_b, o_r, o_c = output.size()
+     s_b, s_r, s_c = x_exp_row_wise_sum.shape
+     s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
+     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
+     s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(softmax_kernel)[triton_grid]
+      (x_exp,
+       x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
+       s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+       sparsity_reverse_lut_rws,
+       output,
+       sparsity_block_size))
+
+     return output
+
+
+ def softmax_backward_wrapper(ctx, grad_output):
+     o, sparsity_layout, sparsity_lut = ctx.saved_tensors
+     sparsity_block_size = ctx.sparsity_block_size
+
+     return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
+                             sparsity_block_size), None, None, None, None, None
+
+
+ @triton_op("blksprs::softmax_backward", mutates_args={})
+ def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
+                      sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)

          sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
          sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
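
The context lines closing this hunk (unchanged between versions) build the reverse look-up table with a cumsum trick: each position in the flattened layout maps to its running index in compressed storage if the block is present, and to -1 if it is absent. A small numeric sketch with a made-up layout:

import torch

flat = torch.tensor([1, 0, 1, 1, 0])  # a flattened sparsity layout
rev = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1)
       - (1 * (flat == 0)))
print(rev)  # tensor([ 0, -1,  1,  2, -1])
# Present blocks map to their slot in compressed storage; absent blocks map to -1.
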
@@ -143,13 +116,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
          s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
          s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)

-         grad_x = torch.empty_like(o, dtype=torch.float)
+         grad_x = torch.zeros_like(o, dtype=torch.float)

          triton_grid = lambda meta: [o_b,
                                      triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                      triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-         (_BlocksparseSoftmax.kernel_blocksparse_softmax_grad_x[triton_grid]
+         (wrap_triton(softmax_kernel_grad)[triton_grid]
           (grad_output,
            o_b, o_b_s, o_r_s, o_c_s,
            o,
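
Both launch sites now route through wrap_triton, and the Python wrappers are registered as custom operators via @triton_op, which lets torch.compile trace through the Triton kernels instead of graph-breaking on them. A self-contained sketch of the same pattern on a toy kernel; it uses the public torch.library entry points available in recent PyTorch releases, whereas blksprs itself imports the private torch._library paths:

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def add_one_kernel(x_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(o_ptr + offsets, tl.load(x_ptr + offsets, mask=mask) + 1.0, mask=mask)

@triton_op("examples::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # wrap_triton makes the raw kernel launch visible to torch.compile
    wrap_triton(add_one_kernel)[grid](x, output, n_elements, BLOCK_SIZE=1024)
    return output
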
@@ -161,118 +134,171 @@
           sparsity_reverse_lut_s,
           grad_x,
           o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size
-          ))
-
-         return grad_x, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_softmax(x,
-                                    x_b, x_b_s, x_r_s, x_c_s,
-                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                    s, s_b, s_b_s, s_r_s, s_c_s,
-                                    s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                    r_lut_s,
-                                    o,
-                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch and row index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         # Get reverse sparsity indices for s
-         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                              spa_row * s_l_s_r_s)
-         rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-         if rev_idx_spa_s >= 0:
-             # Load x block
-             blk_x_idx = ((pid_blk * x_b_s) +
-                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-             blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-             # Load sum block
-             blk_s_idx = (rev_idx_spa_s * s_b_s +
-                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
-                          (tl.arange(0, 1) * s_c_s)[None, :])
-             blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
-             blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
-
-             # Compute softmax
-             buf = tl.div_rn(blk_x, blk_s)
-
-             # Store output
-             tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_softmax_grad_x(g,
-                                           g_b, g_b_s, g_r_s, g_c_s,
-                                           x,
-                                           x_b, x_b_s, x_r_s, x_c_s,
-                                           s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                           s,
-                                           s_b, s_b_s, s_r_s, s_c_s,
-                                           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                           r_lut_s,
-                                           o,
-                                           o_b, o_b_s, o_r_s, o_c_s,
-                                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch and row index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                              spa_row * s_l_s_r_s)
-         rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-         if rev_idx_spa_s >= 0:
-             blk_s_idx = (rev_idx_spa_s * s_b_s +
-                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
-                          (tl.arange(0, 1) * s_c_s)[None, :])
-             blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
-             blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
-
-             blk_g_idx = ((pid_blk * g_b_s) +
-                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
-                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-             blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
-             blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
-
-             blk_x_idx = ((pid_blk * x_b_s) +
-                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-             blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-             buf = blk_x * (blk_g - blk_s)
-
-             blk_o_idx = ((pid_blk * o_b_s) +
-                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-             blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-             tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+          sparsity_block_size))
+
+         return grad_x
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def softmax_kernel(x,
+                    x_b, x_b_s, x_r_s, x_c_s,
+                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                    s, s_b, s_b_s, s_r_s, s_c_s,
+                    s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                    r_lut_s,
+                    o,
+                    sparsity_block_size,
+                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     # Get reverse sparsity indices for s
+     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                          spa_row * s_l_s_r_s)
+     rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+     if rev_idx_spa_s >= 0:
+         # Load x block
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Load sum block
+         blk_s_idx = (rev_idx_spa_s * s_b_s +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                      (tl.arange(0, 1) * s_c_s)[None, :])
+         blk_s_msk = (blk_s_idx >= 0 and
+                      blk_s_idx < s_b * s_b_s)
+         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
+
+         # Compute softmax
+         buf = tl.div_rn(blk_x, blk_s)
+
+         # Store output
+         tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def softmax_kernel_grad(g,
+                         g_b, g_b_s, g_r_s, g_c_s,
+                         x,
+                         x_b, x_b_s, x_r_s, x_c_s,
+                         s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                         s,
+                         s_b, s_b_s, s_r_s, s_c_s,
+                         s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                         r_lut_s,
+                         o,
+                         o_b, o_b_s, o_r_s, o_c_s,
+                         sparsity_block_size,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                          spa_row * s_l_s_r_s)
+     rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+     if rev_idx_spa_s >= 0:
+         blk_s_idx = (rev_idx_spa_s * s_b_s +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                      (tl.arange(0, 1) * s_c_s)[None, :])
+         blk_s_msk = (blk_s_idx >= 0 and
+                      blk_s_idx < s_b * s_b_s)
+         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
+
+         blk_g_idx = ((pid_blk * g_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
+         blk_g_msk = (blk_g_idx >= 0 and
+                      blk_g_idx < g_b * g_b_s)
+         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
+
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         buf = blk_x * (blk_g - blk_s)
+
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx >= 0 and
+                      blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+
+ def softmax_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+         lut["sparsity_lut"] = sparsity_lut
+
+     if "sparsity_reverse_lut_rws" not in lut:
+         sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+         sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
+         sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
+                                     (sparsity_layout_rws_flat == 1) -
+                                     (1 * (sparsity_layout_rws_flat == 0)))
+         lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
+
+     validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def softmax_setup_context(ctx, inputs, output):
+     (_, sparsity_layout, sparsity_lut, _, sparsity_block_size) = inputs
+
+     ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
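
The file now ends by wiring the backward pass through the custom-operator machinery rather than a torch.autograd.Function subclass: softmax_setup_context takes over the ctx bookkeeping the old forward did inline, and register_autograd attaches the backward wrapper. A minimal sketch of that wiring on a toy op (all names below are illustrative, not part of blksprs):

import torch
from torch.library import triton_op

@triton_op("examples::square", mutates_args={})
def square(x: torch.Tensor) -> torch.Tensor:
    return x * x

def square_setup_context(ctx, inputs, output):
    # Mirrors softmax_setup_context: stash whatever backward will need
    (x,) = inputs
    ctx.save_for_backward(x)

def square_backward(ctx, grad_output):
    (x,) = ctx.saved_tensors
    return 2.0 * x * grad_output  # d/dx x^2 = 2x

square.register_autograd(square_backward, setup_context=square_setup_context)
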