blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/softmax.py CHANGED
@@ -1,18 +1,20 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl

- from blksprs.ops.misc.exp import exp
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32


- def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-             triton_block_size: int = None) -> BlksprsTensor:
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
      """Computes the softmax of a block-sparse tensor in compressed form.

      Note:
@@ -22,7 +24,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
          x (BlksprsTensor): A block-sparse tensor in compressed form.
          sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
          BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
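In 2.0 the triton_block_size argument is gone (block sizes are now picked by the autotuner, see the kernels below), and callers can instead pass a prebuilt lut dictionary so the look-up tables are not recomputed on every call. A minimal usage sketch, assuming the compressed form described in the docstring; the shapes and layout values here are hypothetical:

    import torch
    from blksprs.ops.softmax import softmax, softmax_build_lut

    # Hypothetical layout: one batch, a 2x2 grid of 32x32 blocks, 3 blocks present.
    sparsity_block_size = 32
    sparsity_layout = torch.tensor([[[1, 0], [1, 1]]], device="cuda")

    # Compressed form stacks the present blocks along the first dimension.
    x = torch.rand(3, sparsity_block_size, sparsity_block_size, device="cuda")

    # Build the look-up tables once and reuse them across calls sharing this layout.
    lut = softmax_build_lut(None, sparsity_layout)
    y = softmax(x, sparsity_layout, sparsity_block_size, lut=lut)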
@@ -32,88 +34,73 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,

      validate_dimensions(x)
      validate_contiguous(x)
+     validate_dtype_float_32(x)
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

-     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-
-     sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-     sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
-     sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
-                                 (sparsity_layout_rws_flat == 1) -
-                                 (1 * (sparsity_layout_rws_flat == 0)))
-
-     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
-
-     return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
-                                                    sparsity_lut,
-                                                    sparsity_reverse_lut_rws,
-                                                    sparsity_block_size, triton_block_size))
-
-
- class _BlocksparseSoftmax(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout: Tensor,
-                 sparsity_lut: Tensor,
-                 sparsity_reverse_lut_rws: Tensor,
-                 sparsity_block_size: int, triton_block_size: int) -> Tensor:
-         output = torch.empty_like(x)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_lut_r, s_lut_c = sparsity_lut.size()
-         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-         o_b, o_r, o_c = output.size()
-
-         x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
-                                                            flag_slice_only=True,
-                                                            triton_block_size=triton_block_size)
-         x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
-         x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)
-         x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
-                                                                flag_slice_only=True,
-                                                                triton_block_size=triton_block_size)
-
-         s_b, s_r, s_c = x_exp_row_wise_sum.shape
-         s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
-         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
-         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]
-          (x_exp,
-           x_b, x_b_s, x_r_s, x_c_s,
-           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-           x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
-           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-           sparsity_reverse_lut_rws,
-           output,
-           triton_block_size))
-
-         # Save for backward pass
-         ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         o, sparsity_layout, sparsity_lut = ctx.saved_tensors
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
-                                             triton_block_size=triton_block_size)
+     lut = softmax_build_lut(lut, sparsity_layout)
+
+     return BlksprsTensor(softmax_forward(x, sparsity_layout,
+                                          lut["sparsity_lut"],
+                                          lut["sparsity_reverse_lut_rws"],
+                                          sparsity_block_size))
+
+
+ @triton_op("blksprs::softmax_forward", mutates_args={})
+ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
+                     sparsity_lut: Tensor,
+                     sparsity_reverse_lut_rws: Tensor,
+                     sparsity_block_size: int) -> Tensor:
+     output = torch.zeros_like(x)
+
+     x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
+                                                        flag_slice_only=True)
+     x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
+     x_exp = torch.exp(x_scaled)
+     x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
+                                                            flag_slice_only=True)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     o_b, o_r, o_c = output.size()
+     s_b, s_r, s_c = x_exp_row_wise_sum.shape
+     s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
+     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
+     s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(softmax_kernel)[triton_grid]
+      (x_exp,
+       x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
+       s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+       sparsity_reverse_lut_rws,
+       output,
+       sparsity_block_size))
+
+     return output
+
+
+ def softmax_backward_wrapper(ctx, grad_output):
+     o, sparsity_layout, sparsity_lut = ctx.saved_tensors
+     sparsity_block_size = ctx.sparsity_block_size
+
+     return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
+                             sparsity_block_size), None, None, None, None, None
+
+
+ @triton_op("blksprs::softmax_backward", mutates_args={})
+ def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
+                      sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)

          sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
          sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
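The 2.0 rewrite above drops the torch.autograd.Function subclass in favor of PyTorch's custom-operator registration: @triton_op declares softmax_forward as a real operator, and wrap_triton wraps the kernel handle so the launch stays traceable by torch.compile. Below is a self-contained sketch of the same pattern on a toy kernel; the demo::scale op and everything in it are hypothetical and not part of blksprs, and newer PyTorch releases also expose these helpers publicly as torch.library.triton_op and torch.library.wrap_triton:

    import torch
    import triton
    from torch._library import triton_op
    from torch._library.triton import wrap_triton
    from triton import language as tl

    @triton.jit
    def scale_kernel(x_ptr, o_ptr, n, factor, BLOCK: tl.constexpr):
        # Elementwise multiply of a flat tensor by a scalar factor.
        pid = tl.program_id(axis=0)
        idx = pid * BLOCK + tl.arange(0, BLOCK)
        msk = idx < n
        tl.store(o_ptr + idx, tl.load(x_ptr + idx, mask=msk) * factor, mask=msk)

    @triton_op("demo::scale", mutates_args={})
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        output = torch.empty_like(x)
        grid = lambda meta: [triton.cdiv(x.numel(), meta["BLOCK"])]
        # wrap_triton makes the kernel launch visible to torch.compile tracing.
        wrap_triton(scale_kernel)[grid](x, output, x.numel(), factor, BLOCK=128)
        return output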
@@ -129,13 +116,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
          s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
          s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)

-         grad_x = torch.empty_like(o, dtype=torch.float)
+         grad_x = torch.zeros_like(o, dtype=torch.float)

          triton_grid = lambda meta: [o_b,
                                      triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                      triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-         (_BlocksparseSoftmax.kernel_blocksparse_softmax_grad_x[triton_grid]
+         (wrap_triton(softmax_kernel_grad)[triton_grid]
          (grad_output,
           o_b, o_b_s, o_r_s, o_c_s,
           o,
@@ -147,57 +134,63 @@ class _BlocksparseSoftmax(torch.autograd.Function):
           sparsity_reverse_lut_s,
           grad_x,
           o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size
-          ))
-
-         return grad_x, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_softmax(x,
-                                    x_b, x_b_s, x_r_s, x_c_s,
-                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                    s, s_b, s_b_s, s_r_s, s_c_s,
-                                    s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                    r_lut_s,
-                                    o,
-                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch and row index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         # Get reverse sparsity indices for s
-         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                              spa_row * s_l_s_r_s)
-         rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-         if rev_idx_spa_s == -1:
-             tl.device_assert(False)
-             return
-
+          sparsity_block_size))
+
+     return grad_x
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def softmax_kernel(x,
+                    x_b, x_b_s, x_r_s, x_c_s,
+                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                    s, s_b, s_b_s, s_r_s, s_c_s,
+                    s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                    r_lut_s,
+                    o,
+                    sparsity_block_size,
+                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     # Get reverse sparsity indices for s
+     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                          spa_row * s_l_s_r_s)
+     rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+     if rev_idx_spa_s >= 0:
          # Load x block
          blk_x_idx = ((pid_blk * x_b_s) +
                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                       ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+         blk_x_msk = (blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s)
          blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

          # Load sum block
          blk_s_idx = (rev_idx_spa_s * s_b_s +
                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                       (tl.arange(0, 1) * s_c_s)[None, :])
-         blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
+         blk_s_msk = (blk_s_idx >= 0 and
+                      blk_s_idx < s_b * s_b_s)
          blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

          # Compute softmax
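For reference, the forward composition implements numerically stable softmax: row_wise_max and row_wise_sub shift each row by its maximum, torch.exp exponentiates, and the kernel performs only the final division by the row-wise sum fetched through the reverse look-up table. A dense equivalent, as an illustration only (the block-sparse version additionally skips absent blocks):

    import torch

    def dense_softmax_reference(x_dense: torch.Tensor) -> torch.Tensor:
        # Numerically stable softmax over the last dimension: subtracting the
        # row maximum leaves the result unchanged but avoids overflow in exp.
        x_shifted = x_dense - x_dense.max(dim=-1, keepdim=True).values
        x_exp = torch.exp(x_shifted)
        return x_exp / x_exp.sum(dim=-1, keepdim=True)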
@@ -206,59 +199,67 @@ class _BlocksparseSoftmax(torch.autograd.Function):
          # Store output
          tl.store(o + blk_x_idx, buf, mask=blk_x_msk)

-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_softmax_grad_x(g,
-                                           g_b, g_b_s, g_r_s, g_c_s,
-                                           x,
-                                           x_b, x_b_s, x_r_s, x_c_s,
-                                           s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                           s,
-                                           s_b, s_b_s, s_r_s, s_c_s,
-                                           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                           r_lut_s,
-                                           o,
-                                           o_b, o_b_s, o_r_s, o_c_s,
-                                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch and row index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                              spa_row * s_l_s_r_s)
-         rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-         if rev_idx_spa_s == -1:
-             tl.device_assert(False)
-             return

+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def softmax_kernel_grad(g,
+                         g_b, g_b_s, g_r_s, g_c_s,
+                         x,
+                         x_b, x_b_s, x_r_s, x_c_s,
+                         s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                         s,
+                         s_b, s_b_s, s_r_s, s_c_s,
+                         s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                         r_lut_s,
+                         o,
+                         o_b, o_b_s, o_r_s, o_c_s,
+                         sparsity_block_size,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                          spa_row * s_l_s_r_s)
+     rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+     if rev_idx_spa_s >= 0:
          blk_s_idx = (rev_idx_spa_s * s_b_s +
                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                       (tl.arange(0, 1) * s_c_s)[None, :])
-         blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
+         blk_s_msk = (blk_s_idx >= 0 and
+                      blk_s_idx < s_b * s_b_s)
          blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

          blk_g_idx = ((pid_blk * g_b_s) +
                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
                       ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-         blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
+         blk_g_msk = (blk_g_idx >= 0 and
+                      blk_g_idx < g_b * g_b_s)
          blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)

          blk_x_idx = ((pid_blk * x_b_s) +
                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                       ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+         blk_x_msk = (blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s)
          blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

          buf = blk_x * (blk_g - blk_s)
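The buf expression is the standard softmax vector-Jacobian product: for y = softmax(x) with upstream gradient g, grad_x = y * (g - rowsum(g * y)), where the row-wise sum s was precomputed by row_wise_sum before the kernel launch (note that the kernel's x argument receives the saved forward output o). A dense equivalent for illustration:

    import torch

    def dense_softmax_grad_reference(g: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Vector-Jacobian product of softmax: y * (g - sum(g * y)) per row.
        s = (g * y).sum(dim=-1, keepdim=True)
        return y * (g - s)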
@@ -266,5 +267,38 @@ class _BlocksparseSoftmax(torch.autograd.Function):
          blk_o_idx = ((pid_blk * o_b_s) +
                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                       ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+         blk_o_msk = (blk_o_idx >= 0 and
+                      blk_o_idx < o_b * o_b_s)
          tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+
+ def softmax_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+         lut["sparsity_lut"] = sparsity_lut
+
+     if "sparsity_reverse_lut_rws" not in lut:
+         sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+         sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
+         sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
+                                     (sparsity_layout_rws_flat == 1) -
+                                     (1 * (sparsity_layout_rws_flat == 0)))
+         lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
+
+     validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def softmax_setup_context(ctx, inputs, output):
+     (_, sparsity_layout, sparsity_lut, _, sparsity_block_size) = inputs
+
+     ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
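The closing register_autograd call is what replaces the old autograd.Function wiring: softmax_setup_context stashes whatever the backward needs, and softmax_backward_wrapper returns one gradient (or None) per forward input. Continuing the hypothetical demo::scale sketch from earlier, the same wiring would look like this:

    def scale_setup_context(ctx, inputs, output):
        # Stash the non-tensor argument for the backward pass.
        _, factor = inputs
        ctx.factor = factor

    def scale_backward(ctx, grad_output):
        # Return one gradient per forward input; the float factor gets None.
        return grad_output * ctx.factor, None

    scale.register_autograd(scale_backward, setup_context=scale_setup_context)

    # Gradients now flow through the custom op as a single differentiable unit:
    x = torch.rand(1024, device="cuda", requires_grad=True)
    scale(x, 2.0).sum().backward()  # x.grad is 2.0 everywhere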