blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
blksprs/ops/softmax.py CHANGED
@@ -1,18 +1,18 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl

- from blksprs.ops.misc.exp import exp
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride, get_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity, validate_sparsity_block_size


- def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-             triton_block_size: int = None) -> BlksprsTensor:
+ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
      """Computes the softmax of a block-sparse tensor in compressed form.

      Note:
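
The import changes above mark the move from a torch.autograd.Function subclass (removed further down) to PyTorch's torch.library custom-operator path: triton_op registers the Python wrapper as a real operator, and wrap_triton makes the Triton kernel launch traceable so torch.compile does not graph-break on it. A minimal sketch of that pattern, using a hypothetical "mylib::scale" operator that is not part of this diff:

    import torch
    import triton
    from torch._library import triton_op
    from torch._library.triton import wrap_triton
    from triton import language as tl


    @triton.jit
    def scale_kernel(x_ptr, o_ptr, n, factor, BLOCK: tl.constexpr):
        # Multiply each element of x by factor.
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        msk = offs < n
        tl.store(o_ptr + offs, tl.load(x_ptr + offs, mask=msk) * factor, mask=msk)


    @triton_op("mylib::scale", mutates_args={})
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        # wrap_triton, rather than a bare scale_kernel[grid](...) launch,
        # keeps the kernel launch visible to torch.compile.
        output = torch.empty_like(x)
        n = x.numel()
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
        wrap_triton(scale_kernel)[grid](x, output, n, factor, BLOCK=1024)
        return output
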
@@ -22,7 +22,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
          x (BlksprsTensor): A block-sparse tensor in compressed form.
          sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
          BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
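
With the new signature, the look-up tables can be hoisted out of hot loops: pass the same dict on every call and the tables are built once and reused, since softmax_build_lut (added below) only fills in missing keys. A rough usage sketch; the way x is constructed here is illustrative only, assuming the compressed form stacks the non-zero blocks along the first dimension:

    import torch

    from blksprs.ops.softmax import softmax
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    sparsity_block_size = 64
    # Batch of 2, a 4 x 4 grid of sparsity blocks, every block present.
    sparsity_layout = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")

    # Compressed form: one [block, block] slice per non-zero layout entry.
    x = BlksprsTensor(torch.randn(int(sparsity_layout.sum()),
                                  sparsity_block_size, sparsity_block_size,
                                  device="cuda"))

    lut = {}  # populated on the first call, reused afterwards
    for _ in range(10):
        y = softmax(x, sparsity_layout, sparsity_block_size, lut=lut)
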
@@ -35,169 +35,156 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-
-     sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-     sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
-     sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
-                                 (sparsity_layout_rws_flat == 1) -
-                                 (1 * (sparsity_layout_rws_flat == 0)))
-
-     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
-
-     return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
-                                                    sparsity_lut,
-                                                    sparsity_reverse_lut_rws,
-                                                    sparsity_block_size, triton_block_size))
-
-
- class _BlocksparseSoftmax(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout: Tensor,
-                 sparsity_lut: Tensor,
-                 sparsity_reverse_lut_rws: Tensor,
-                 sparsity_block_size: int, triton_block_size: int) -> Tensor:
-         output = torch.empty_like(x)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_lut_r, s_lut_c = sparsity_lut.size()
-         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-         o_b, o_r, o_c = output.size()
-
-         x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
-                                                            flag_slice_only=True,
-                                                            triton_block_size=triton_block_size)
-         x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
-         x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)
-         x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
-                                                                flag_slice_only=True,
-                                                                triton_block_size=triton_block_size)
-
-         s_b, s_r, s_c = x_exp_row_wise_sum.shape
-         s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
-         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
-         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]
-          (x_exp,
-           x_b, x_b_s, x_r_s, x_c_s,
-           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-           x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
-           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-           sparsity_reverse_lut_rws,
-           output,
-           triton_block_size))
-
-         # Save for backward pass
-         ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         o, sparsity_layout, sparsity_lut = ctx.saved_tensors
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
-                                             triton_block_size=triton_block_size)
-
-         sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
-         sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
-                                   (sparsity_layout_s_flat == 1) -
-                                   (1 * (sparsity_layout_s_flat == 0)))
-
-         o_b, o_r, o_c = o.size()
-         o_b_s, o_r_s, o_c_s = stride(o)
-         s_lut_r, s_lut_c = sparsity_lut.size()
-         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-         s_b, s_r, s_c = s.size()
-         s_b_s, s_r_s, s_c_s = stride(s)
-         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
-         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
-
-         grad_x = torch.empty_like(o, dtype=torch.float)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseSoftmax.kernel_blocksparse_softmax_grad_x[triton_grid]
-          (grad_output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           o,
-           o_b, o_b_s, o_r_s, o_c_s,
-           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-           s,
-           s_b, s_b_s, s_r_s, s_c_s,
-           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-           sparsity_reverse_lut_s,
-           grad_x,
-           o_b, o_b_s, o_r_s, o_c_s,
-           triton_block_size
-           ))
-
-         return grad_x, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_softmax(x,
-                                    x_b, x_b_s, x_r_s, x_c_s,
-                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                    s, s_b, s_b_s, s_r_s, s_c_s,
-                                    s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                    r_lut_s,
-                                    o,
-                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch and row index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         # Get reverse sparsity indices for s
-         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                              spa_row * s_l_s_r_s)
-         rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-         if rev_idx_spa_s == -1:
-             tl.device_assert(False)
-             return

+     lut = softmax_build_lut(lut, sparsity_layout)
+
+     return BlksprsTensor(softmax_forward(x, sparsity_layout,
+                                          lut["sparsity_lut"],
+                                          lut["sparsity_reverse_lut_rws"],
+                                          sparsity_block_size))
+
+
+ @triton_op("blksprs::softmax", mutates_args={})
+ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
+                     sparsity_lut: Tensor,
+                     sparsity_reverse_lut_rws: Tensor,
+                     sparsity_block_size: int) -> Tensor:
+     output = torch.empty_like(x)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     o_b, o_r, o_c = output.size()
+
+     x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
+                                                        flag_slice_only=True)
+     x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
+     x_exp = torch.exp(x_scaled)
+     x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
+                                                            flag_slice_only=True)
+
+     s_b, s_r, s_c = x_exp_row_wise_sum.shape
+     s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
+     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
+     s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(softmax_kernel)[triton_grid]
+      (x_exp,
+       x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
+       s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+       sparsity_reverse_lut_rws,
+       output,
+       sparsity_block_size))
+
+     return output
+
+
+ def softmax_backward(ctx, grad_output):
+     o, sparsity_layout, sparsity_lut = ctx.saved_tensors
+     sparsity_block_size = ctx.sparsity_block_size
+
+     s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
+
+     sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
+     sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
+                               (sparsity_layout_s_flat == 1) -
+                               (1 * (sparsity_layout_s_flat == 0)))
+
+     o_b, o_r, o_c = o.size()
+     o_b_s, o_r_s, o_c_s = stride(o)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     s_b, s_r, s_c = s.size()
+     s_b_s, s_r_s, s_c_s = stride(s)
+     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
+     s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
+
+     grad_x = torch.empty_like(o, dtype=torch.float)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(softmax_kernel_grad)[triton_grid]
+      (grad_output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       o,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       s,
+       s_b, s_b_s, s_r_s, s_c_s,
+       s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+       sparsity_reverse_lut_s,
+       grad_x,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_block_size))
+
+     return grad_x, None, None, None, None, None
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[]
+ )
+ @triton.jit
+ def softmax_kernel(x,
+                    x_b, x_b_s, x_r_s, x_c_s,
+                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                    s, s_b, s_b_s, s_r_s, s_c_s,
+                    s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                    r_lut_s,
+                    o,
+                    sparsity_block_size,
+                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     # Get reverse sparsity indices for s
+     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                          spa_row * s_l_s_r_s)
+     rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+     if rev_idx_spa_s >= 0:
          # Load x block
          blk_x_idx = ((pid_blk * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = ((blk_x_idx >= 0 and
+                       blk_x_idx < x_b * x_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
          blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

          # Load sum block
          blk_s_idx = (rev_idx_spa_s * s_b_s +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                       (tl.arange(0, 1) * s_c_s)[None, :])
-         blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
+         blk_s_msk = ((blk_s_idx >= 0 and
+                       blk_s_idx < s_b * s_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, 1)[None, :] < val_tbs))
          blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

          # Compute softmax
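
A note on the cumsum expression that reappears throughout this hunk (for sparsity_reverse_lut_rws and sparsity_reverse_lut_s): it maps every position of the flattened sparsity layout to the index of that block within the compressed tensor, and to -1 where no block exists; the rewritten kernels then branch on rev_idx_spa_s >= 0 where the old code device-asserted on == -1. A standalone illustration of the encoding, with an arbitrary layout:

    import torch

    flat = torch.tensor([1, 0, 1, 1, 0])  # flattened sparsity layout
    reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) *
                   (flat == 1) -
                   (1 * (flat == 0)))
    print(reverse_lut)  # tensor([ 0, -1,  1,  2, -1])
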
@@ -206,65 +193,114 @@ class _BlocksparseSoftmax(torch.autograd.Function):
          # Store output
          tl.store(o + blk_x_idx, buf, mask=blk_x_msk)

-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_softmax_grad_x(g,
-                                           g_b, g_b_s, g_r_s, g_c_s,
-                                           x,
-                                           x_b, x_b_s, x_r_s, x_c_s,
-                                           s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                           s,
-                                           s_b, s_b_s, s_r_s, s_c_s,
-                                           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                           r_lut_s,
-                                           o,
-                                           o_b, o_b_s, o_r_s, o_c_s,
-                                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch and row index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                              spa_row * s_l_s_r_s)
-         rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-         if rev_idx_spa_s == -1:
-             tl.device_assert(False)
-             return

+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[]
+ )
+ @triton.jit
+ def softmax_kernel_grad(g,
+                         g_b, g_b_s, g_r_s, g_c_s,
+                         x,
+                         x_b, x_b_s, x_r_s, x_c_s,
+                         s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                         s,
+                         s_b, s_b_s, s_r_s, s_c_s,
+                         s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                         r_lut_s,
+                         o,
+                         o_b, o_b_s, o_r_s, o_c_s,
+                         sparsity_block_size,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                          spa_row * s_l_s_r_s)
+     rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+     if rev_idx_spa_s >= 0:
          blk_s_idx = (rev_idx_spa_s * s_b_s +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                       (tl.arange(0, 1) * s_c_s)[None, :])
-         blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
+         blk_s_msk = ((blk_s_idx >= 0 and
+                       blk_s_idx < s_b * s_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, 1)[None, :] < val_tbs))
          blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

          blk_g_idx = ((pid_blk * g_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-         blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
+         blk_g_msk = ((blk_g_idx >= 0 and
+                       blk_g_idx < g_b * g_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
          blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)

          blk_x_idx = ((pid_blk * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = ((blk_x_idx >= 0 and
+                       blk_x_idx < x_b * x_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
          blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

          buf = blk_x * (blk_g - blk_s)

          blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = ((blk_o_idx >= 0 and
+                       blk_o_idx < o_b * o_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
          tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+
+ def softmax_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+         lut["sparsity_lut"] = sparsity_lut
+
+     if "sparsity_reverse_lut_rws" not in lut:
+         sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+         sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
+         sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
+                                     (sparsity_layout_rws_flat == 1) -
+                                     (1 * (sparsity_layout_rws_flat == 0)))
+         lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
+
+     validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def softmax_setup_context(ctx, inputs, output):
+     (_, sparsity_layout, sparsity_lut, _, sparsity_block_size) = inputs
+
+     ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ softmax_forward.register_autograd(softmax_backward, setup_context=softmax_setup_context)
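
For reference, the registered backward implements the standard softmax vector-Jacobian product: with y = softmax(x) and incoming gradient g, the input gradient is y * (g - sum(g * y)) taken row-wise, which is exactly what softmax_kernel_grad computes as buf = blk_x * (blk_g - blk_s), blk_s holding the row-wise sum produced by row_wise_sum(grad_output * o, ...). A dense sanity check of the identity in plain PyTorch, independent of blksprs:

    import torch

    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    g = torch.randn(4, 8, dtype=torch.double)

    y = torch.softmax(x, dim=-1)
    y.backward(g)

    manual = y * (g - (g * y).sum(dim=-1, keepdim=True))
    print(torch.allclose(x.grad, manual))  # True
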