blksprs 2.0rc4__py3-none-any.whl → 2.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,8 @@ from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
6
  from blksprs.utils.blksprs_tensor import BlksprsTensor
7
- from blksprs.utils.tools import stride, get_autotune_configs
7
+ from blksprs.utils.tools import stride
8
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
8
9
  from blksprs.utils.validation import validate_dimensions, validate_device, \
9
10
  validate_contiguous
10
11
 
@@ -47,6 +48,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
47
48
  triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
48
49
  triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
49
50
 
51
+ # TODO wrap
50
52
  (build_distribution_layout_kernel[triton_grid]
51
53
  (indices,
52
54
  i_b, i_b_s, i_r_s, i_c_s,
@@ -62,7 +64,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
62
64
 
63
65
  @triton.autotune(
64
66
  configs=get_autotune_configs(),
65
- key=[],
67
+ key=["sparsity_block_size"],
68
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
66
69
  reset_to_zero=["o"]
67
70
  )
68
71
  @triton.jit
@@ -80,9 +83,6 @@ def build_distribution_layout_kernel(i,
80
83
  pid_row = tl.program_id(axis=1)
81
84
  pid_col = tl.program_id(axis=2)
82
85
 
83
- # Get valid triton block size
84
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
85
-
86
86
  # Get position of current sparsity block consisting of its batch, row, and column index
87
87
  spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
88
88
  spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
@@ -97,12 +97,10 @@ def build_distribution_layout_kernel(i,
97
97
  spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
98
98
 
99
99
  blk_i_idx = (pid_blk * i_b_s +
100
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
101
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
102
- blk_i_msk = ((blk_i_idx >= 0 and
103
- blk_i_idx < i_b * i_b_s) and
104
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
105
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
100
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
101
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
102
+ blk_i_msk = (blk_i_idx >= 0 and
103
+ blk_i_idx < i_b * i_b_s)
106
104
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
107
105
 
108
106
  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -120,8 +118,6 @@ def build_distribution_layout_kernel(i,
120
118
  blk_o_idx = ((dst_bat_idx * o_b_s) +
121
119
  (dst_row_idx * o_r_s) +
122
120
  (dst_col_idx * o_c_s))
123
- blk_o_msk = ((blk_o_idx >= 0 and
124
- blk_o_idx < o_b * o_b_s) and
125
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
126
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
121
+ blk_o_msk = (blk_o_idx >= 0 and
122
+ blk_o_idx < o_b * o_b_s)
127
123
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
@@ -7,7 +7,8 @@ from torch._library.triton import wrap_triton
7
7
  from triton import language as tl
8
8
 
9
9
  from blksprs.utils.blksprs_tensor import BlksprsTensor
10
- from blksprs.utils.tools import stride, get_autotune_configs
10
+ from blksprs.utils.tools import stride
11
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
11
12
  from blksprs.utils.validation import validate_dimensions, validate_device, \
12
13
  validate_contiguous, validate_sparsity, validate_sparsity_block_size
13
14
 
@@ -37,10 +38,11 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
37
38
  o_b_s, o_r_s, o_c_s = stride(output)
38
39
 
39
40
  triton_grid = lambda meta: [x_b,
40
- triton.cdiv(x_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
41
- triton.cdiv(x_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
41
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
42
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
42
43
 
43
- (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
44
+ # TODO wrap
45
+ (build_sparsity_layout_kernel[triton_grid]
44
46
  (x,
45
47
  x_b, x_b_s, x_r_s, x_c_s,
46
48
  output,
@@ -52,7 +54,8 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
52
54
 
53
55
  @triton.autotune(
54
56
  configs=get_autotune_configs(),
55
- key=[],
57
+ key=["sparsity_block_size"],
58
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
56
59
  reset_to_zero=["o"]
57
60
  )
58
61
  @triton.jit
@@ -67,24 +70,19 @@ def build_sparsity_layout_kernel(x,
67
70
  pid_row = tl.program_id(axis=1)
68
71
  pid_col = tl.program_id(axis=2)
69
72
 
70
- # Get valid triton block size
71
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
72
-
73
73
  # Load x values
74
74
  blk_x_idx = (pid_bat * x_b_s +
75
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
76
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
77
- blk_x_msk = ((blk_x_idx >= 0 and
78
- blk_x_idx < x_b * x_b_s) and
79
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
80
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
75
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
76
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
77
+ blk_x_msk = (blk_x_idx >= 0 and
78
+ blk_x_idx < x_b * x_b_s)
81
79
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
82
80
 
83
81
  # Store sparsity layout value
84
82
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
85
83
  blk_o_idx = (pid_bat * o_b_s +
86
- (((pid_row * val_tbs) // sparsity_block_size) * o_r_s +
87
- ((pid_col * val_tbs) // sparsity_block_size) * o_c_s))
84
+ (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
85
+ ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
88
86
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
89
87
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
90
88
 
@@ -129,10 +127,11 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
129
127
  o_b_s, o_r_s, o_c_s = stride(output)
130
128
 
131
129
  triton_grid = lambda meta: [x_b,
132
- triton.cdiv(x_r, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"])),
133
- triton.cdiv(x_c, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"]))]
130
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
131
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
134
132
 
135
- (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
133
+ # TODO wrap
134
+ (build_sparsity_layout_adaption_kernel[triton_grid]
136
135
  (x,
137
136
  x_b, x_b_s, x_r_s, x_c_s,
138
137
  sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -146,7 +145,8 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
146
145
 
147
146
  @triton.autotune(
148
147
  configs=get_autotune_configs(),
149
- key=[],
148
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
149
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
150
150
  reset_to_zero=["o"]
151
151
  )
152
152
  @triton.jit
@@ -163,9 +163,6 @@ def build_sparsity_layout_adaption_kernel(x,
163
163
  pid_row = tl.program_id(axis=1)
164
164
  pid_col = tl.program_id(axis=2)
165
165
 
166
- # Get valid triton block size
167
- val_tbs = min(sparsity_block_size_to, TRITON_BLOCK_SIZE)
168
-
169
166
  # Get sparsity index of current output block consisting of its batch, row, and column index
170
167
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
171
168
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -181,20 +178,18 @@ def build_sparsity_layout_adaption_kernel(x,
181
178
 
182
179
  # Load x values
183
180
  blk_x_idx = ((pid_blk * x_b_s) +
184
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
185
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
186
- blk_x_msk = ((blk_x_idx >= 0 and
187
- blk_x_idx < x_b * x_b_s) and
188
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < sparsity_block_size_from and
189
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < sparsity_block_size_from))
181
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
182
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
183
+ blk_x_msk = (blk_x_idx >= 0 and
184
+ blk_x_idx < x_b * x_b_s)
190
185
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
191
186
 
192
187
  # Store sparsity layout value
193
188
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
194
189
  blk_o_idx = ((spa_bat * o_b_s) +
195
- (((pid_row * val_tbs + spa_row * sparsity_block_size_from)
190
+ (((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size_from)
196
191
  // sparsity_block_size_to) * o_r_s) +
197
- (((pid_col * val_tbs + spa_col * sparsity_block_size_from)
192
+ (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
198
193
  // sparsity_block_size_to) * o_c_s))
199
194
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
200
195
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
blksprs/ops/conversion.py CHANGED
@@ -6,7 +6,8 @@ from triton import language as tl
6
6
 
7
7
  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense
12
13
 
@@ -86,7 +87,8 @@ def to_sparse_backward(ctx, grad_output):
86
87
 
87
88
  @triton.autotune(
88
89
  configs=get_autotune_configs(),
89
- key=[],
90
+ key=["sparsity_block_size"],
91
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
90
92
  reset_to_zero=["o"]
91
93
  )
92
94
  @triton.jit
@@ -102,9 +104,6 @@ def to_sparse_kernel(x,
102
104
  pid_row = tl.program_id(axis=1)
103
105
  pid_col = tl.program_id(axis=2)
104
106
 
105
- # Get valid triton block size
106
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
107
-
108
107
  # Get sparsity index of current output block consisting of its batch, row, and column index
109
108
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
110
109
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -120,24 +119,20 @@ def to_sparse_kernel(x,
120
119
 
121
120
  # Load block from dense tensor
122
121
  blk_d_idx = (spa_bat * x_b_s +
123
- ((pid_row * val_tbs + spa_row * sparsity_block_size +
122
+ ((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
124
123
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
125
- ((pid_col * val_tbs + spa_col * sparsity_block_size +
124
+ ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
126
125
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
127
- blk_d_msk = ((blk_d_idx >= 0 and
128
- blk_d_idx < x_b * x_b_s) and
129
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
130
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
126
+ blk_d_msk = (blk_d_idx >= 0 and
127
+ blk_d_idx < x_b * x_b_s)
131
128
  blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
132
129
 
133
130
  # Store block in sparse tensor
134
131
  blk_o_idx = ((pid_blk * o_b_s) +
135
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
136
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
137
- blk_o_msk = ((blk_o_idx >= 0 and
138
- blk_o_idx < (pid_blk + 1) * o_b_s) and
139
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
140
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
132
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
133
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
134
+ blk_o_msk = (blk_o_idx >= 0 and
135
+ blk_o_idx < (pid_blk + 1) * o_b_s)
141
136
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
142
137
 
143
138
 
@@ -228,8 +223,8 @@ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
228
223
  o_b_s, o_r_s, o_c_s = stride(output)
229
224
 
230
225
  triton_grid = lambda meta: [o_b,
231
- triton.cdiv(o_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
232
- triton.cdiv(o_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
226
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
227
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
233
228
 
234
229
  (wrap_triton(to_dense_kernel)[triton_grid]
235
230
  (x,
@@ -252,7 +247,8 @@ def to_dense_backward(ctx, grad_output):
252
247
 
253
248
  @triton.autotune(
254
249
  configs=get_autotune_configs(),
255
- key=[],
250
+ key=["sparsity_block_size"],
251
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
256
252
  restore_value=["o"]
257
253
  )
258
254
  @triton.jit
@@ -269,12 +265,9 @@ def to_dense_kernel(x,
269
265
  pid_row = tl.program_id(axis=1)
270
266
  pid_col = tl.program_id(axis=2)
271
267
 
272
- # Get valid triton block size
273
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
274
-
275
268
  # Get sparsity index of current block
276
- spa_row = (pid_row * val_tbs) // sparsity_block_size
277
- spa_col = (pid_col * val_tbs) // sparsity_block_size
269
+ spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
270
+ spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
278
271
 
279
272
  # Get reverse sparsity index for current block
280
273
  rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
@@ -284,22 +277,18 @@ def to_dense_kernel(x,
284
277
  # If block is present commence operations
285
278
  if rev_idx_spa >= 0:
286
279
  blk_idx = (rev_idx_spa * x_b_s +
287
- (((pid_row % (sparsity_block_size // val_tbs)) * val_tbs +
280
+ (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
288
281
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
289
- (((pid_col % (sparsity_block_size // val_tbs)) * val_tbs +
282
+ (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
290
283
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
291
- blk_msk = ((blk_idx >= 0 and
292
- blk_idx < x_b * x_b_s) and
293
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
294
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
284
+ blk_msk = (blk_idx >= 0 and
285
+ blk_idx < x_b * x_b_s)
295
286
  blk = tl.load(x + blk_idx, mask=blk_msk)
296
287
 
297
288
  o_idx = (pid_blk * o_b_s +
298
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
299
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
300
- o_msk = ((o_idx >= 0 and o_idx < o_b * o_b_s) and
301
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
302
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
289
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
290
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
291
+ o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
303
292
  tl.store(o + o_idx, blk, o_msk)
304
293
 
305
294
 
@@ -403,12 +392,11 @@ def adapt_layout_forward(x: Tensor,
403
392
  s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
404
393
 
405
394
  triton_grid = lambda meta: [o_b,
406
- triton.cdiv(o_r, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
407
- meta["TRITON_BLOCK_SIZE"])),
408
- triton.cdiv(o_c, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
409
- meta["TRITON_BLOCK_SIZE"]))]
395
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
396
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
410
397
 
411
- (wrap_triton(adapt_layout_kernel)[triton_grid]
398
+ # TODO wrap
399
+ (adapt_layout_kernel[triton_grid]
412
400
  (x,
413
401
  x_b, x_b_s, x_r_s, x_c_s,
414
402
  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -434,7 +422,8 @@ def adapt_layout_backward(ctx, grad_output):
434
422
 
435
423
  @triton.autotune(
436
424
  configs=get_autotune_configs(),
437
- key=[],
425
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
426
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
438
427
  reset_to_zero=["o"]
439
428
  )
440
429
  @triton.jit
@@ -453,9 +442,6 @@ def adapt_layout_kernel(x,
453
442
  pid_row = tl.program_id(axis=1)
454
443
  pid_col = tl.program_id(axis=2)
455
444
 
456
- # Get valid triton block size (Triton can only handle 2-valued min)
457
- val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)
458
-
459
445
  # Get position of current sparsity block consisting of its batch, row, and column index
460
446
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
461
447
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -471,8 +457,8 @@ def adapt_layout_kernel(x,
471
457
 
472
458
  # Get equivalent sparsity block in from layout
473
459
  spa_bat_x = spa_bat_o
474
- spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * val_tbs) // sparsity_block_size_from
475
- spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * val_tbs) // sparsity_block_size_from
460
+ spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
461
+ spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
476
462
 
477
463
  # Get reverse sparsity indices for x
478
464
  rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
@@ -484,29 +470,25 @@ def adapt_layout_kernel(x,
484
470
  # If block is present commence operations
485
471
  if rev_idx_spa_x >= 0:
486
472
  # Calculate triton block size shifts
487
- shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * val_tbs)
488
- % sparsity_block_size_from) // val_tbs
489
- shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * val_tbs)
490
- % sparsity_block_size_from) // val_tbs
473
+ shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
474
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
475
+ shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
476
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
491
477
 
492
478
  # Load x values
493
479
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
494
- ((shift_row_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
495
- ((shift_col_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
496
- blk_x_msk = ((blk_x_idx >= 0 and
497
- blk_x_idx < x_b * x_b_s) and
498
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
499
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
480
+ ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
481
+ ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
482
+ blk_x_msk = (blk_x_idx >= 0 and
483
+ blk_x_idx < x_b * x_b_s)
500
484
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
501
485
 
502
486
  # Store output
503
487
  blk_o_idx = ((pid_blk * o_b_s) +
504
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
505
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
506
- blk_o_msk = ((blk_o_idx >= 0 and
507
- blk_o_idx < o_b * o_b_s) and
508
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
509
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
488
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
489
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
490
+ blk_o_msk = (blk_o_idx >= 0 and
491
+ blk_o_idx < o_b * o_b_s)
510
492
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
511
493
 
512
494
 
@@ -6,7 +6,8 @@ from torch._library.triton import wrap_triton
6
6
  from triton import language as tl
7
7
 
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size
12
13
 
@@ -100,7 +101,8 @@ def gather_backward(ctx, grad_output):
100
101
 
101
102
  @triton.autotune(
102
103
  configs=get_autotune_configs(),
103
- key=[],
104
+ key=["sparsity_block_size"],
105
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
104
106
  reset_to_zero=["o"]
105
107
  )
106
108
  @triton.jit
@@ -121,9 +123,6 @@ def gather_kernel(x,
121
123
  pid_row = tl.program_id(axis=1)
122
124
  pid_col = tl.program_id(axis=2)
123
125
 
124
- # Get valid triton block size
125
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
126
-
127
126
  # Get position of current sparsity block consisting of its batch, row, and column index
128
127
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
129
128
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -139,12 +138,10 @@ def gather_kernel(x,
139
138
 
140
139
  # Load index values
141
140
  blk_i_idx = ((pid_blk * i_b_s) +
142
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
143
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
144
- blk_i_msk = ((blk_i_idx >= 0 and
145
- blk_i_idx < i_b * i_b_s) and
146
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
147
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
141
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
142
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
143
+ blk_i_msk = (blk_i_idx >= 0 and
144
+ blk_i_idx < i_b * i_b_s)
148
145
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
149
146
 
150
147
  # Get indices of sparsity blocks and positions within the blocks
@@ -154,9 +151,9 @@ def gather_kernel(x,
154
151
  rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
155
152
  rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
156
153
  rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
157
- dst_row_x = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
154
+ dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
158
155
  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
159
- dst_col_x = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
156
+ dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
160
157
  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
161
158
  if dim == 0:
162
159
  rev_dst_bat_x = blk_i
@@ -171,32 +168,26 @@ def gather_kernel(x,
171
168
  rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
172
169
  (rev_dst_row_x * s_l_x_r_s) +
173
170
  (rev_dst_col_x * s_l_x_c_s))
174
- rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0 and
175
- rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s) and
176
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
177
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
171
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
172
+ rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
178
173
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
179
174
 
180
175
  # Load x values
181
176
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
182
177
  dst_row_x +
183
178
  dst_col_x)
184
- blk_x_msk = (((blk_x_idx >= 0 and
185
- blk_x_idx < x_b * x_b_s) and
186
- rev_idx_spa_x_msk != -1) and
187
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
188
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
179
+ blk_x_msk = ((blk_x_idx >= 0 and
180
+ blk_x_idx < x_b * x_b_s) and
181
+ rev_idx_spa_x_msk != -1)
189
182
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
190
183
 
191
184
  # Store output
192
185
  blk_o_idx = ((pid_blk * o_b_s) +
193
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
194
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
195
- blk_o_msk = (((blk_o_idx >= 0 and
196
- blk_o_idx < o_b * o_b_s) and
197
- rev_idx_spa_x_msk != -1) and
198
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
199
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
186
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
187
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
188
+ blk_o_msk = ((blk_o_idx >= 0 and
189
+ blk_o_idx < o_b * o_b_s) and
190
+ rev_idx_spa_x_msk != -1)
200
191
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
201
192
 
202
193
 
@@ -249,7 +240,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
249
240
  reduce_op="none", lut=lut)
250
241
 
251
242
 
252
- @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
243
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
253
244
  def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
254
245
  dim: int,
255
246
  idx: BlksprsTensor,
@@ -357,7 +348,8 @@ def scatter_reduce_backward(ctx, grad_output):
357
348
 
358
349
  @triton.autotune(
359
350
  configs=get_autotune_configs(),
360
- key=[],
351
+ key=["sparsity_block_size"],
352
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
361
353
  reset_to_zero=["o"]
362
354
  )
363
355
  @triton.jit
@@ -379,9 +371,6 @@ def scatter_reduce_kernel(x,
379
371
  pid_row = tl.program_id(axis=1)
380
372
  pid_col = tl.program_id(axis=2)
381
373
 
382
- # Get valid triton block size
383
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
384
-
385
374
  # Get position of current sparsity block consisting of its batch, row, and column index
386
375
  spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
387
376
  spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
@@ -397,22 +386,18 @@ def scatter_reduce_kernel(x,
397
386
 
398
387
  # Load x values
399
388
  blk_x_idx = ((pid_blk * x_b_s) +
400
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
401
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
402
- blk_x_msk = ((blk_x_idx >= 0 and
403
- blk_x_idx < x_b * x_b_s) and
404
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
405
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
389
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
390
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
391
+ blk_x_msk = (blk_x_idx >= 0 and
392
+ blk_x_idx < x_b * x_b_s)
406
393
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
407
394
 
408
395
  # Load index values
409
396
  blk_i_idx = ((pid_blk * i_b_s) +
410
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
411
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
412
- blk_i_msk = ((blk_i_idx >= 0 and
413
- blk_i_idx < i_b * i_b_s) and
414
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
415
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
397
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
398
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
399
+ blk_i_msk = (blk_i_idx >= 0 and
400
+ blk_i_idx < i_b * i_b_s)
416
401
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
417
402
 
418
403
  # Get indices of sparsity blocks and positions within the blocks
@@ -422,9 +407,9 @@ def scatter_reduce_kernel(x,
422
407
  rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
423
408
  rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
424
409
  rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
425
- dst_row_o = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
410
+ dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
426
411
  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
427
- dst_col_o = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
412
+ dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
428
413
  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
429
414
  if dim == 0:
430
415
  rev_dst_bat_o = blk_i
@@ -439,21 +424,17 @@ def scatter_reduce_kernel(x,
439
424
  rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
440
425
  (rev_dst_row_o * s_l_o_r_s) +
441
426
  (rev_dst_col_o * s_l_o_c_s))
442
- rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0 and
443
- rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s) and
444
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
445
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
427
+ rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
428
+ rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
446
429
  rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
447
430
 
448
431
  # Store output
449
432
  blk_o_idx = ((rev_idx_spa_o * o_b_s) +
450
433
  dst_row_o +
451
434
  dst_col_o)
452
- blk_o_msk = (((blk_o_idx >= 0 and
453
- blk_o_idx < o_b * o_b_s) and
454
- rev_idx_spa_o_msk != -1) and
455
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
456
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
435
+ blk_o_msk = ((blk_o_idx >= 0 and
436
+ blk_o_idx < o_b * o_b_s) and
437
+ rev_idx_spa_o_msk != -1)
457
438
 
458
439
  if reduce_op_ind == 0:
459
440
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)