blksprs 2.0rc3__py3-none-any.whl → 2.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,8 @@ from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
6
  from blksprs.utils.blksprs_tensor import BlksprsTensor
7
- from blksprs.utils.tools import stride, get_autotune_configs
7
+ from blksprs.utils.tools import stride
8
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
8
9
  from blksprs.utils.validation import validate_dimensions, validate_device, \
9
10
  validate_contiguous
10
11
 
@@ -47,6 +48,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
47
48
  triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
48
49
  triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
49
50
 
51
+ # TODO wrap
50
52
  (build_distribution_layout_kernel[triton_grid]
51
53
  (indices,
52
54
  i_b, i_b_s, i_r_s, i_c_s,
@@ -62,7 +64,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
62
64
 
63
65
  @triton.autotune(
64
66
  configs=get_autotune_configs(),
65
- key=[],
67
+ key=["sparsity_block_size"],
68
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
66
69
  reset_to_zero=["o"]
67
70
  )
68
71
  @triton.jit
@@ -80,9 +83,6 @@ def build_distribution_layout_kernel(i,
80
83
  pid_row = tl.program_id(axis=1)
81
84
  pid_col = tl.program_id(axis=2)
82
85
 
83
- # Get valid triton block size
84
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
85
-
86
86
  # Get position of current sparsity block consisting of its batch, row, and column index
87
87
  spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
88
88
  spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
@@ -97,12 +97,10 @@ def build_distribution_layout_kernel(i,
97
97
  spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
98
98
 
99
99
  blk_i_idx = (pid_blk * i_b_s +
100
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
101
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
102
- blk_i_msk = ((blk_i_idx >= 0 and
103
- blk_i_idx < i_b * i_b_s) and
104
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
105
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
100
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
101
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
102
+ blk_i_msk = (blk_i_idx >= 0 and
103
+ blk_i_idx < i_b * i_b_s)
106
104
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
107
105
 
108
106
  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -120,8 +118,6 @@ def build_distribution_layout_kernel(i,
120
118
  blk_o_idx = ((dst_bat_idx * o_b_s) +
121
119
  (dst_row_idx * o_r_s) +
122
120
  (dst_col_idx * o_c_s))
123
- blk_o_msk = ((blk_o_idx >= 0 and
124
- blk_o_idx < o_b * o_b_s) and
125
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
126
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
121
+ blk_o_msk = (blk_o_idx >= 0 and
122
+ blk_o_idx < o_b * o_b_s)
127
123
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
@@ -7,7 +7,8 @@ from torch._library.triton import wrap_triton
7
7
  from triton import language as tl
8
8
 
9
9
  from blksprs.utils.blksprs_tensor import BlksprsTensor
10
- from blksprs.utils.tools import stride, get_autotune_configs
10
+ from blksprs.utils.tools import stride
11
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
11
12
  from blksprs.utils.validation import validate_dimensions, validate_device, \
12
13
  validate_contiguous, validate_sparsity, validate_sparsity_block_size
13
14
 
@@ -37,10 +38,11 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
37
38
  o_b_s, o_r_s, o_c_s = stride(output)
38
39
 
39
40
  triton_grid = lambda meta: [x_b,
40
- triton.cdiv(x_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
41
- triton.cdiv(x_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
41
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
42
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
42
43
 
43
- (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
44
+ # TODO wrap
45
+ (build_sparsity_layout_kernel[triton_grid]
44
46
  (x,
45
47
  x_b, x_b_s, x_r_s, x_c_s,
46
48
  output,
@@ -52,7 +54,8 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
52
54
 
53
55
  @triton.autotune(
54
56
  configs=get_autotune_configs(),
55
- key=[],
57
+ key=["sparsity_block_size"],
58
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
56
59
  reset_to_zero=["o"]
57
60
  )
58
61
  @triton.jit
@@ -67,24 +70,19 @@ def build_sparsity_layout_kernel(x,
67
70
  pid_row = tl.program_id(axis=1)
68
71
  pid_col = tl.program_id(axis=2)
69
72
 
70
- # Get valid triton block size
71
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
72
-
73
73
  # Load x values
74
74
  blk_x_idx = (pid_bat * x_b_s +
75
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
76
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
77
- blk_x_msk = ((blk_x_idx >= 0 and
78
- blk_x_idx < x_b * x_b_s) and
79
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
80
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
75
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
76
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
77
+ blk_x_msk = (blk_x_idx >= 0 and
78
+ blk_x_idx < x_b * x_b_s)
81
79
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
82
80
 
83
81
  # Store sparsity layout value
84
82
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
85
83
  blk_o_idx = (pid_bat * o_b_s +
86
- (((pid_row * val_tbs) // sparsity_block_size) * o_r_s +
87
- ((pid_col * val_tbs) // sparsity_block_size) * o_c_s))
84
+ (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
85
+ ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
88
86
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
89
87
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
90
88
 
@@ -129,10 +127,11 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
129
127
  o_b_s, o_r_s, o_c_s = stride(output)
130
128
 
131
129
  triton_grid = lambda meta: [x_b,
132
- triton.cdiv(x_r, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"])),
133
- triton.cdiv(x_c, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"]))]
130
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
131
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
134
132
 
135
- (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
133
+ # TODO wrap
134
+ (build_sparsity_layout_adaption_kernel[triton_grid]
136
135
  (x,
137
136
  x_b, x_b_s, x_r_s, x_c_s,
138
137
  sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -146,7 +145,8 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
146
145
 
147
146
  @triton.autotune(
148
147
  configs=get_autotune_configs(),
149
- key=[],
148
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
149
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
150
150
  reset_to_zero=["o"]
151
151
  )
152
152
  @triton.jit
@@ -163,9 +163,6 @@ def build_sparsity_layout_adaption_kernel(x,
163
163
  pid_row = tl.program_id(axis=1)
164
164
  pid_col = tl.program_id(axis=2)
165
165
 
166
- # Get valid triton block size
167
- val_tbs = min(sparsity_block_size_to, TRITON_BLOCK_SIZE)
168
-
169
166
  # Get sparsity index of current output block consisting of its batch, row, and column index
170
167
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
171
168
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -181,20 +178,18 @@ def build_sparsity_layout_adaption_kernel(x,
181
178
 
182
179
  # Load x values
183
180
  blk_x_idx = ((pid_blk * x_b_s) +
184
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
185
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
186
- blk_x_msk = ((blk_x_idx >= 0 and
187
- blk_x_idx < x_b * x_b_s) and
188
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < sparsity_block_size_from and
189
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < sparsity_block_size_from))
181
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
182
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
183
+ blk_x_msk = (blk_x_idx >= 0 and
184
+ blk_x_idx < x_b * x_b_s)
190
185
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
191
186
 
192
187
  # Store sparsity layout value
193
188
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
194
189
  blk_o_idx = ((spa_bat * o_b_s) +
195
- (((pid_row * val_tbs + spa_row * sparsity_block_size_from)
190
+ (((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size_from)
196
191
  // sparsity_block_size_to) * o_r_s) +
197
- (((pid_col * val_tbs + spa_col * sparsity_block_size_from)
192
+ (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
198
193
  // sparsity_block_size_to) * o_c_s))
199
194
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
200
195
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
blksprs/ops/conversion.py CHANGED
@@ -6,7 +6,8 @@ from triton import language as tl
6
6
 
7
7
  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense
12
13
 
@@ -54,7 +55,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
54
55
  @triton_op("blksprs::to_sparse", mutates_args={})
55
56
  def to_sparse_forward(x: Tensor, _: Tensor,
56
57
  sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
57
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
58
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
58
59
  dtype=x.dtype, device=x.device)
59
60
 
60
61
  x_b, x_r, x_c = x.size()
@@ -86,7 +87,9 @@ def to_sparse_backward(ctx, grad_output):
86
87
 
87
88
  @triton.autotune(
88
89
  configs=get_autotune_configs(),
89
- key=[],
90
+ key=["sparsity_block_size"],
91
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
92
+ reset_to_zero=["o"]
90
93
  )
91
94
  @triton.jit
92
95
  def to_sparse_kernel(x,
@@ -101,9 +104,6 @@ def to_sparse_kernel(x,
101
104
  pid_row = tl.program_id(axis=1)
102
105
  pid_col = tl.program_id(axis=2)
103
106
 
104
- # Get valid triton block size
105
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
106
-
107
107
  # Get sparsity index of current output block consisting of its batch, row, and column index
108
108
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
109
109
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -119,24 +119,20 @@ def to_sparse_kernel(x,
119
119
 
120
120
  # Load block from dense tensor
121
121
  blk_d_idx = (spa_bat * x_b_s +
122
- ((pid_row * val_tbs + spa_row * sparsity_block_size +
122
+ ((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
123
123
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
124
- ((pid_col * val_tbs + spa_col * sparsity_block_size +
124
+ ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
125
125
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
126
- blk_d_msk = ((blk_d_idx >= 0 and
127
- blk_d_idx < x_b * x_b_s) and
128
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
129
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
126
+ blk_d_msk = (blk_d_idx >= 0 and
127
+ blk_d_idx < x_b * x_b_s)
130
128
  blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
131
129
 
132
130
  # Store block in sparse tensor
133
131
  blk_o_idx = ((pid_blk * o_b_s) +
134
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
135
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
136
- blk_o_msk = ((blk_o_idx >= 0 and
137
- blk_o_idx < (pid_blk + 1) * o_b_s) and
138
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
139
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
132
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
133
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
134
+ blk_o_msk = (blk_o_idx >= 0 and
135
+ blk_o_idx < (pid_blk + 1) * o_b_s)
140
136
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
141
137
 
142
138
 
@@ -227,8 +223,8 @@ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
227
223
  o_b_s, o_r_s, o_c_s = stride(output)
228
224
 
229
225
  triton_grid = lambda meta: [o_b,
230
- triton.cdiv(o_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
231
- triton.cdiv(o_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
226
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
227
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
232
228
 
233
229
  (wrap_triton(to_dense_kernel)[triton_grid]
234
230
  (x,
@@ -251,7 +247,9 @@ def to_dense_backward(ctx, grad_output):
251
247
 
252
248
  @triton.autotune(
253
249
  configs=get_autotune_configs(),
254
- key=[],
250
+ key=["sparsity_block_size"],
251
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
252
+ restore_value=["o"]
255
253
  )
256
254
  @triton.jit
257
255
  def to_dense_kernel(x,
@@ -267,12 +265,9 @@ def to_dense_kernel(x,
267
265
  pid_row = tl.program_id(axis=1)
268
266
  pid_col = tl.program_id(axis=2)
269
267
 
270
- # Get valid triton block size
271
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
272
-
273
268
  # Get sparsity index of current block
274
- spa_row = (pid_row * val_tbs) // sparsity_block_size
275
- spa_col = (pid_col * val_tbs) // sparsity_block_size
269
+ spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
270
+ spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
276
271
 
277
272
  # Get reverse sparsity index for current block
278
273
  rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
@@ -282,22 +277,18 @@ def to_dense_kernel(x,
282
277
  # If block is present commence operations
283
278
  if rev_idx_spa >= 0:
284
279
  blk_idx = (rev_idx_spa * x_b_s +
285
- (((pid_row % (sparsity_block_size // val_tbs)) * val_tbs +
280
+ (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
286
281
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
287
- (((pid_col % (sparsity_block_size // val_tbs)) * val_tbs +
282
+ (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
288
283
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
289
- blk_msk = ((blk_idx >= 0 and
290
- blk_idx < x_b * x_b_s) and
291
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
292
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
284
+ blk_msk = (blk_idx >= 0 and
285
+ blk_idx < x_b * x_b_s)
293
286
  blk = tl.load(x + blk_idx, mask=blk_msk)
294
287
 
295
288
  o_idx = (pid_blk * o_b_s +
296
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
297
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
298
- o_msk = ((o_idx >= 0 and o_idx < o_b * o_b_s) and
299
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
300
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
289
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
290
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
291
+ o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
301
292
  tl.store(o + o_idx, blk, o_msk)
302
293
 
303
294
 
@@ -401,12 +392,11 @@ def adapt_layout_forward(x: Tensor,
401
392
  s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
402
393
 
403
394
  triton_grid = lambda meta: [o_b,
404
- triton.cdiv(o_r, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
405
- meta["TRITON_BLOCK_SIZE"])),
406
- triton.cdiv(o_c, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
407
- meta["TRITON_BLOCK_SIZE"]))]
395
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
396
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
408
397
 
409
- (wrap_triton(adapt_layout_kernel)[triton_grid]
398
+ # TODO wrap
399
+ (adapt_layout_kernel[triton_grid]
410
400
  (x,
411
401
  x_b, x_b_s, x_r_s, x_c_s,
412
402
  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -432,7 +422,8 @@ def adapt_layout_backward(ctx, grad_output):
432
422
 
433
423
  @triton.autotune(
434
424
  configs=get_autotune_configs(),
435
- key=[],
425
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
426
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
436
427
  reset_to_zero=["o"]
437
428
  )
438
429
  @triton.jit
@@ -451,9 +442,6 @@ def adapt_layout_kernel(x,
451
442
  pid_row = tl.program_id(axis=1)
452
443
  pid_col = tl.program_id(axis=2)
453
444
 
454
- # Get valid triton block size (Triton can only handle 2-valued min)
455
- val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)
456
-
457
445
  # Get position of current sparsity block consisting of its batch, row, and column index
458
446
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
459
447
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -469,8 +457,8 @@ def adapt_layout_kernel(x,
469
457
 
470
458
  # Get equivalent sparsity block in from layout
471
459
  spa_bat_x = spa_bat_o
472
- spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * val_tbs) // sparsity_block_size_from
473
- spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * val_tbs) // sparsity_block_size_from
460
+ spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
461
+ spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
474
462
 
475
463
  # Get reverse sparsity indices for x
476
464
  rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
@@ -482,29 +470,25 @@ def adapt_layout_kernel(x,
482
470
  # If block is present commence operations
483
471
  if rev_idx_spa_x >= 0:
484
472
  # Calculate triton block size shifts
485
- shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * val_tbs)
486
- % sparsity_block_size_from) // val_tbs
487
- shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * val_tbs)
488
- % sparsity_block_size_from) // val_tbs
473
+ shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
474
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
475
+ shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
476
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
489
477
 
490
478
  # Load x values
491
479
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
492
- ((shift_row_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
493
- ((shift_col_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
494
- blk_x_msk = ((blk_x_idx >= 0 and
495
- blk_x_idx < x_b * x_b_s) and
496
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
497
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
480
+ ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
481
+ ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
482
+ blk_x_msk = (blk_x_idx >= 0 and
483
+ blk_x_idx < x_b * x_b_s)
498
484
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
499
485
 
500
486
  # Store output
501
487
  blk_o_idx = ((pid_blk * o_b_s) +
502
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
503
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
504
- blk_o_msk = ((blk_o_idx >= 0 and
505
- blk_o_idx < o_b * o_b_s) and
506
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
507
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
488
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
489
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
490
+ blk_o_msk = (blk_o_idx >= 0 and
491
+ blk_o_idx < o_b * o_b_s)
508
492
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
509
493
 
510
494
 
@@ -6,7 +6,8 @@ from torch._library.triton import wrap_triton
6
6
  from triton import language as tl
7
7
 
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size
12
13
 
@@ -54,7 +55,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
54
55
  def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
55
56
  dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
56
57
  sparsity_block_size: int) -> Tensor:
57
- output = torch.empty_like(i, dtype=x.dtype)
58
+ output = torch.zeros_like(i, dtype=x.dtype)
58
59
 
59
60
  x_b, x_r, x_c = x.size()
60
61
  x_b_s, x_r_s, x_c_s = stride(x)
@@ -100,7 +101,9 @@ def gather_backward(ctx, grad_output):
100
101
 
101
102
  @triton.autotune(
102
103
  configs=get_autotune_configs(),
103
- key=[],
104
+ key=["sparsity_block_size"],
105
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
106
+ reset_to_zero=["o"]
104
107
  )
105
108
  @triton.jit
106
109
  def gather_kernel(x,
@@ -120,9 +123,6 @@ def gather_kernel(x,
120
123
  pid_row = tl.program_id(axis=1)
121
124
  pid_col = tl.program_id(axis=2)
122
125
 
123
- # Get valid triton block size
124
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
125
-
126
126
  # Get position of current sparsity block consisting of its batch, row, and column index
127
127
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
128
128
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -138,12 +138,10 @@ def gather_kernel(x,
138
138
 
139
139
  # Load index values
140
140
  blk_i_idx = ((pid_blk * i_b_s) +
141
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
142
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
143
- blk_i_msk = ((blk_i_idx >= 0 and
144
- blk_i_idx < i_b * i_b_s) and
145
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
146
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
141
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
142
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
143
+ blk_i_msk = (blk_i_idx >= 0 and
144
+ blk_i_idx < i_b * i_b_s)
147
145
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
148
146
 
149
147
  # Get indices of sparsity blocks and positions within the blocks
@@ -153,9 +151,9 @@ def gather_kernel(x,
153
151
  rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
154
152
  rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
155
153
  rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
156
- dst_row_x = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
154
+ dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
157
155
  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
158
- dst_col_x = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
156
+ dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
159
157
  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
160
158
  if dim == 0:
161
159
  rev_dst_bat_x = blk_i
@@ -170,32 +168,26 @@ def gather_kernel(x,
170
168
  rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
171
169
  (rev_dst_row_x * s_l_x_r_s) +
172
170
  (rev_dst_col_x * s_l_x_c_s))
173
- rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0 and
174
- rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s) and
175
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
176
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
171
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
172
+ rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
177
173
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
178
174
 
179
175
  # Load x values
180
176
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
181
177
  dst_row_x +
182
178
  dst_col_x)
183
- blk_x_msk = (((blk_x_idx >= 0 and
184
- blk_x_idx < x_b * x_b_s) and
185
- rev_idx_spa_x_msk != -1) and
186
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
187
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
179
+ blk_x_msk = ((blk_x_idx >= 0 and
180
+ blk_x_idx < x_b * x_b_s) and
181
+ rev_idx_spa_x_msk != -1)
188
182
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
189
183
 
190
184
  # Store output
191
185
  blk_o_idx = ((pid_blk * o_b_s) +
192
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
193
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
194
- blk_o_msk = (((blk_o_idx >= 0 and
195
- blk_o_idx < o_b * o_b_s) and
196
- rev_idx_spa_x_msk != -1) and
197
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
198
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
186
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
187
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
188
+ blk_o_msk = ((blk_o_idx >= 0 and
189
+ blk_o_idx < o_b * o_b_s) and
190
+ rev_idx_spa_x_msk != -1)
199
191
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
200
192
 
201
193
 
@@ -356,7 +348,8 @@ def scatter_reduce_backward(ctx, grad_output):
356
348
 
357
349
  @triton.autotune(
358
350
  configs=get_autotune_configs(),
359
- key=[],
351
+ key=["sparsity_block_size"],
352
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
360
353
  reset_to_zero=["o"]
361
354
  )
362
355
  @triton.jit
@@ -378,9 +371,6 @@ def scatter_reduce_kernel(x,
378
371
  pid_row = tl.program_id(axis=1)
379
372
  pid_col = tl.program_id(axis=2)
380
373
 
381
- # Get valid triton block size
382
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
383
-
384
374
  # Get position of current sparsity block consisting of its batch, row, and column index
385
375
  spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
386
376
  spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
@@ -396,22 +386,18 @@ def scatter_reduce_kernel(x,
396
386
 
397
387
  # Load x values
398
388
  blk_x_idx = ((pid_blk * x_b_s) +
399
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
400
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
401
- blk_x_msk = ((blk_x_idx >= 0 and
402
- blk_x_idx < x_b * x_b_s) and
403
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
404
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
389
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
390
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
391
+ blk_x_msk = (blk_x_idx >= 0 and
392
+ blk_x_idx < x_b * x_b_s)
405
393
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
406
394
 
407
395
  # Load index values
408
396
  blk_i_idx = ((pid_blk * i_b_s) +
409
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
410
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
411
- blk_i_msk = ((blk_i_idx >= 0 and
412
- blk_i_idx < i_b * i_b_s) and
413
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
414
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
397
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
398
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
399
+ blk_i_msk = (blk_i_idx >= 0 and
400
+ blk_i_idx < i_b * i_b_s)
415
401
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
416
402
 
417
403
  # Get indices of sparsity blocks and positions within the blocks
@@ -421,9 +407,9 @@ def scatter_reduce_kernel(x,
421
407
  rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
422
408
  rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
423
409
  rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
424
- dst_row_o = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
410
+ dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
425
411
  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
426
- dst_col_o = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
412
+ dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
427
413
  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
428
414
  if dim == 0:
429
415
  rev_dst_bat_o = blk_i
@@ -438,21 +424,17 @@ def scatter_reduce_kernel(x,
438
424
  rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
439
425
  (rev_dst_row_o * s_l_o_r_s) +
440
426
  (rev_dst_col_o * s_l_o_c_s))
441
- rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0 and
442
- rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s) and
443
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
444
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
427
+ rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
428
+ rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
445
429
  rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
446
430
 
447
431
  # Store output
448
432
  blk_o_idx = ((rev_idx_spa_o * o_b_s) +
449
433
  dst_row_o +
450
434
  dst_col_o)
451
- blk_o_msk = (((blk_o_idx >= 0 and
452
- blk_o_idx < o_b * o_b_s) and
453
- rev_idx_spa_o_msk != -1) and
454
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
455
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
435
+ blk_o_msk = ((blk_o_idx >= 0 and
436
+ blk_o_idx < o_b * o_b_s) and
437
+ rev_idx_spa_o_msk != -1)
456
438
 
457
439
  if reduce_op_ind == 0:
458
440
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)