blksprs 1.11__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,17 @@
1
1
  import torch
2
2
  import triton
3
3
  from torch import Tensor
4
+ from torch._library.triton import wrap_triton, triton_op
4
5
  from triton import language as tl
5
6
 
6
7
  from blksprs.utils.blksprs_tensor import BlksprsTensor
7
- from blksprs.utils.tools import get_triton_block_size, stride
8
+ from blksprs.utils.tools import stride, get_autotune_configs
8
9
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
9
- validate_sparsity_block_size, validate_triton_block_size
10
+ validate_sparsity_block_size
10
11
 
11
12
 
12
13
  def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
13
- flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
14
+ flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
14
15
  """Computes the row-wise sum of a block-sparse tensor.
15
16
 
16
17
  Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
@@ -25,7 +26,6 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
25
26
  sparsity_block_size (int): The size of the sparsity blocks.
26
27
  flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
27
28
  (default ``False``).
28
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
29
29
 
30
30
  Returns:
31
31
  tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
@@ -39,7 +39,6 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
39
39
  validate_device(x)
40
40
  validate_sparsity(sparsity_block_size, (x, sparsity_layout))
41
41
  validate_sparsity_block_size(sparsity_block_size, x)
42
- validate_triton_block_size(triton_block_size, sparsity_block_size)
43
42
 
44
43
  sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
45
44
 
@@ -54,11 +53,19 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
54
53
  validate_contiguous(sparsity_layout, sparsity_lut,
55
54
  sparsity_layout_output, sparsity_reverse_lut_output)
56
55
 
57
- output = torch.zeros(size=(n_sparse_blocks_output,
58
- sparsity_block_size,
59
- 1 if flag_slice_only else sparsity_block_size),
60
- dtype=x.dtype,
61
- device=x.device)
56
+ return BlksprsTensor(row_wise_sum_forward(
57
+ x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output,
58
+ sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
59
+
60
+
61
+ @triton_op("blksprs::row_wise_sum", mutates_args={})
62
+ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
63
+ sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
64
+ sparsity_block_size: int, n_sparse_blocks_output: int,
65
+ flag_slice_only: bool = False) -> Tensor:
66
+ output = torch.zeros(
67
+ size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
68
+ dtype=x.dtype, device=x.device)
62
69
 
63
70
  x_b, x_r, x_c = x.size()
64
71
  x_b_s, x_r_s, x_c_s = stride(x)
@@ -69,14 +76,11 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
69
76
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
70
77
  s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
71
78
 
72
- if triton_block_size is None:
73
- triton_block_size = get_triton_block_size(sparsity_block_size)
74
-
75
79
  triton_grid = lambda meta: [x_b,
76
80
  triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
77
81
  triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
78
82
 
79
- (kernel_blocksparse_row_wise_sum[triton_grid]
83
+ (wrap_triton(row_wise_sum_kernel)[triton_grid]
80
84
  (x,
81
85
  x_b, x_b_s, x_r_s, x_c_s,
82
86
  sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
@@ -84,24 +88,34 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
84
88
  o_b, o_b_s, o_r_s,
85
89
  s_l_o_b, s_l_o_b_s, s_l_o_r_s,
86
90
  sparsity_reverse_lut_output,
87
- triton_block_size))
91
+ sparsity_block_size))
88
92
 
89
- return BlksprsTensor(output), sparsity_layout_output
93
+ return output
90
94
 
91
95
 
96
+ @triton.autotune(
97
+ configs=get_autotune_configs(),
98
+ key=[],
99
+ reset_to_zero=["o"]
100
+ )
92
101
  @triton.jit
93
- def kernel_blocksparse_row_wise_sum(x,
94
- x_b, x_b_s, x_r_s, x_c_s,
95
- s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
96
- o,
97
- o_b, o_b_s, o_r_s,
98
- s_l_o_b, s_l_o_b_s, s_l_o_r_s,
99
- r_lut_o,
100
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
102
+ def row_wise_sum_kernel(x,
103
+ x_b, x_b_s, x_r_s, x_c_s,
104
+ s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
105
+ o,
106
+ o_b, o_b_s, o_r_s,
107
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s,
108
+ r_lut_o,
109
+ sparsity_block_size,
110
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
111
+ # Get triton block indices
101
112
  pid_blk = tl.program_id(axis=0)
102
113
  pid_row = tl.program_id(axis=1)
103
114
  pid_col = tl.program_id(axis=2)
104
115
 
116
+ # Get valid triton block size
117
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
118
+
105
119
  # Get position of current sparsity block consisting of its batch and row index
106
120
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
107
121
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -122,22 +136,28 @@ def kernel_blocksparse_row_wise_sum(x,
122
136
  return
123
137
 
124
138
  blk_idx = ((pid_blk * x_b_s) +
125
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
126
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
127
- blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
139
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
140
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
141
+ blk_msk = ((blk_idx >= 0 and
142
+ blk_idx < x_b * x_b_s) and
143
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
144
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
128
145
  blk = tl.load(x + blk_idx, mask=blk_msk)
129
146
 
130
147
  buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
131
148
 
132
149
  o_idx = (rev_idx_spa * o_b_s +
133
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
150
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
134
151
  (tl.arange(0, 1))[None, :])
135
- o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
152
+ o_msk = ((o_idx >= 0 and
153
+ o_idx < o_b * o_b_s) and
154
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
155
+ tl.arange(0, 1)[None, :] < val_tbs))
136
156
  tl.atomic_add(o + o_idx, buf, o_msk)
137
157
 
138
158
 
139
159
  def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
140
- flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
160
+ flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
141
161
  """Computes the row-wise max of a block-sparse tensor.
142
162
 
143
163
  Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
@@ -152,7 +172,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
152
172
  sparsity_block_size (int): The size of the sparsity blocks.
153
173
  flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
154
174
  (default ``False``).
155
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
156
175
 
157
176
  Returns:
158
177
  tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
@@ -166,7 +185,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
166
185
  validate_device(x)
167
186
  validate_sparsity(sparsity_block_size, (x, sparsity_layout))
168
187
  validate_sparsity_block_size(sparsity_block_size, x)
169
- validate_triton_block_size(triton_block_size, sparsity_block_size)
170
188
 
171
189
  sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
172
190
 
@@ -181,6 +199,16 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
181
199
  validate_contiguous(sparsity_layout, sparsity_lut,
182
200
  sparsity_layout_output, sparsity_reverse_lut_output)
183
201
 
202
+ return BlksprsTensor(
203
+ row_wise_max_forward(x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output, sparsity_block_size,
204
+ n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
205
+
206
+
207
+ @triton_op("blksprs::row_wise_max", mutates_args={})
208
+ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
209
+ sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
210
+ sparsity_block_size: int, n_sparse_blocks_output: int,
211
+ flag_slice_only: bool = False) -> Tensor:
184
212
  output = torch.full(size=(n_sparse_blocks_output,
185
213
  sparsity_block_size,
186
214
  1 if flag_slice_only else sparsity_block_size),
@@ -196,14 +224,11 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
196
224
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
197
225
  s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
198
226
 
199
- if triton_block_size is None:
200
- triton_block_size = get_triton_block_size(sparsity_block_size)
201
-
202
227
  triton_grid = lambda meta: [x_b,
203
228
  triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
204
229
  triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
205
230
 
206
- (kernel_blocksparse_row_wise_max[triton_grid]
231
+ (wrap_triton(row_wise_max_kernel)[triton_grid]
207
232
  (x,
208
233
  x_b, x_b_s, x_r_s, x_c_s,
209
234
  sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
@@ -211,24 +236,34 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
211
236
  o_b, o_b_s, o_r_s,
212
237
  s_l_o_b, s_l_o_b_s, s_l_o_r_s,
213
238
  sparsity_reverse_lut_output,
214
- triton_block_size))
239
+ sparsity_block_size))
215
240
 
216
- return BlksprsTensor(output), sparsity_layout_output
241
+ return output
217
242
 
218
243
 
244
+ @triton.autotune(
245
+ configs=get_autotune_configs(),
246
+ key=[],
247
+ restore_value=["o"]
248
+ )
219
249
  @triton.jit
220
- def kernel_blocksparse_row_wise_max(x,
221
- x_b, x_b_s, x_r_s, x_c_s,
222
- s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
223
- o,
224
- o_b, o_b_s, o_r_s,
225
- s_l_o_b, s_l_o_b_s, s_l_o_r_s,
226
- r_lut_o,
227
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
250
+ def row_wise_max_kernel(x,
251
+ x_b, x_b_s, x_r_s, x_c_s,
252
+ s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
253
+ o,
254
+ o_b, o_b_s, o_r_s,
255
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s,
256
+ r_lut_o,
257
+ sparsity_block_size,
258
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
259
+ # Get triton block indices
228
260
  pid_blk = tl.program_id(axis=0)
229
261
  pid_row = tl.program_id(axis=1)
230
262
  pid_col = tl.program_id(axis=2)
231
263
 
264
+ # Get valid triton block size
265
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
266
+
232
267
  # Get position of current sparsity block consisting of its batch and row index
233
268
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
234
269
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -249,22 +284,28 @@ def kernel_blocksparse_row_wise_max(x,
249
284
  return
250
285
 
251
286
  blk_idx = ((pid_blk * x_b_s) +
252
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
253
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
254
- blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
287
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
288
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
289
+ blk_msk = ((blk_idx >= 0 and
290
+ blk_idx < x_b * x_b_s) and
291
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
292
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
255
293
  blk = tl.load(x + blk_idx, mask=blk_msk)
256
294
 
257
295
  buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
258
296
 
259
297
  o_idx = (rev_idx_spa * o_b_s +
260
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
298
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
261
299
  (tl.arange(0, 1))[None, :])
262
- o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
300
+ o_msk = ((o_idx >= 0 and
301
+ o_idx < o_b * o_b_s) and
302
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
303
+ tl.arange(0, 1)[None, :] < val_tbs))
263
304
  tl.atomic_max(o + o_idx, buf, o_msk)
264
305
 
265
306
 
266
307
  def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
267
- sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
308
+ sparsity_block_size: int) -> BlksprsTensor:
268
309
  """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
269
310
 
270
311
  Args:
@@ -272,7 +313,6 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
272
313
  sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
273
314
  y (BlksprsTensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
274
315
  sparsity_block_size (int): The size of the sparsity blocks.
275
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
276
316
 
277
317
  Returns:
278
318
  BlksprsTensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
@@ -284,9 +324,8 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
284
324
  validate_device(x)
285
325
  validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
286
326
  validate_sparsity_block_size(sparsity_block_size, x)
287
- validate_triton_block_size(triton_block_size, sparsity_block_size)
288
327
 
289
- sparsity_lut = torch.nonzero(sparsity_layout_x).contiguous()
328
+ sparsity_lut_x = torch.nonzero(sparsity_layout_x).contiguous()
290
329
 
291
330
  sparsity_layout_rwm, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)
292
331
  sparsity_layout_rwm_flat = sparsity_layout_rwm.reshape(-1)
@@ -294,24 +333,37 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
294
333
  (sparsity_layout_rwm_flat == 1) -
295
334
  (1 * (sparsity_layout_rwm_flat == 0)))
296
335
 
297
- validate_contiguous(sparsity_layout_x, sparsity_lut, sparsity_reverse_lut_rwm)
336
+ validate_contiguous(sparsity_layout_x, sparsity_lut_x, sparsity_reverse_lut_rwm)
337
+
338
+ return BlksprsTensor(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
339
+ sparsity_reverse_lut_rwm, y, sparsity_block_size))
298
340
 
341
+
342
+ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
343
+ sparsity_block_size: int) -> BlksprsTensor:
344
+ """Wrapper for ``row_wise_add`` with negated y.
345
+
346
+ """
347
+ return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)
348
+
349
+
350
+ @triton_op("blksprs::row_wise_add", mutates_args={})
351
+ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
352
+ sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
353
+ y: Tensor, sparsity_block_size: int) -> Tensor:
299
354
  output = torch.empty_like(x)
300
355
 
301
356
  x_b, x_r, x_c = x.size()
302
357
  x_b_s, x_r_s, x_c_s = stride(x)
303
- s_lut_r, s_lut_c = sparsity_lut.size()
304
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
358
+ s_lut_r, s_lut_c = sparsity_lut_x.size()
359
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
305
360
  y_b, y_r, y_c = y.size()
306
361
  y_b_s, y_r_s, y_c_s = stride(y)
307
- s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
308
- s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
362
+ s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
363
+ s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
309
364
  o_b, o_r, o_c = output.size()
310
365
  o_b_s, o_r_s, o_c_s = stride(output)
311
366
 
312
- if triton_block_size is None:
313
- triton_block_size = get_triton_block_size(sparsity_block_size)
314
-
315
367
  triton_grid = lambda meta: [o_b,
316
368
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
317
369
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
@@ -319,49 +371,48 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
319
371
  (kernel_blocksparse_row_wise_add[triton_grid]
320
372
  (x,
321
373
  x_b, x_b_s, x_r_s, x_c_s,
322
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
374
+ sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
323
375
  y, y_b, y_b_s, y_r_s, y_c_s,
324
376
  s_l_y_b, s_l_y_b_s, s_l_y_r_s,
325
- sparsity_reverse_lut_rwm,
377
+ sparsity_reverse_x_lut_rwm,
326
378
  output,
327
379
  o_b, o_b_s, o_r_s, o_c_s,
328
- triton_block_size
329
- ))
330
-
331
- return BlksprsTensor(output)
332
-
333
-
334
- def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
335
- sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
336
- """Wrapper for ``row_wise_add`` with negated y.
380
+ sparsity_block_size))
337
381
 
338
- """
339
- return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size, triton_block_size)
382
+ return output
340
383
 
341
384
 
385
+ @triton.autotune(
386
+ configs=get_autotune_configs(),
387
+ key=[]
388
+ )
342
389
  @triton.jit
343
390
  def kernel_blocksparse_row_wise_add(x,
344
391
  x_b, x_b_s, x_r_s, x_c_s,
345
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
392
+ s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
346
393
  y, y_b, y_b_s, y_r_s, y_c_s,
347
394
  s_l_y_b, s_l_y_b_s, s_l_y_r_s,
348
395
  r_lut_y,
349
396
  o,
350
397
  o_b, o_b_s, o_r_s, o_c_s,
398
+ sparsity_block_size,
351
399
  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
352
400
  # Get triton block indices
353
401
  pid_blk = tl.program_id(axis=0)
354
402
  pid_row = tl.program_id(axis=1)
355
403
  pid_col = tl.program_id(axis=2)
356
404
 
405
+ # Get valid triton block size
406
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
407
+
357
408
  # Get position of current sparsity block consisting of its batch and row index
358
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
359
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
360
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
409
+ spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
410
+ spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
411
+ spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
361
412
 
362
- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
363
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
364
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
413
+ spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
414
+ spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
415
+ spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
365
416
 
366
417
  # Get reverse sparsity indices for s
367
418
  rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
@@ -375,16 +426,22 @@ def kernel_blocksparse_row_wise_add(x,
375
426
 
376
427
  # Load x block
377
428
  blk_x_idx = ((pid_blk * x_b_s) +
378
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
379
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
380
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
429
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
430
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
431
+ blk_x_msk = ((blk_x_idx >= 0 and
432
+ blk_x_idx < x_b * x_b_s) and
433
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
434
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
381
435
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
382
436
 
383
437
  # Load the per-row values of y
384
438
  blk_s_idx = (rev_idx_spa_s * y_b_s +
385
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
439
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
386
440
  (tl.arange(0, 1) * y_c_s)[None, :])
387
- blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < y_b * y_b_s)
441
+ blk_s_msk = ((blk_s_idx >= 0 and
442
+ blk_s_idx < y_b * y_b_s) and
443
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
444
+ tl.arange(0, 1)[None, :] < val_tbs))
388
445
  blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
389
446
 
390
447
  # Compute row-wise addition
@@ -392,7 +449,10 @@ def kernel_blocksparse_row_wise_add(x,
392
449
 
393
450
  # Store block
394
451
  blk_o_idx = ((pid_blk * o_b_s) +
395
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
396
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
397
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
452
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
453
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
454
+ blk_o_msk = ((blk_o_idx >= 0 and
455
+ blk_o_idx < o_b * o_b_s) and
456
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
457
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
398
458
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)