blksprs 2.0rc3__tar.gz → 2.0rc6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {blksprs-2.0rc3 → blksprs-2.0rc6}/PKG-INFO +14 -5
  2. {blksprs-2.0rc3 → blksprs-2.0rc6}/README.md +13 -4
  3. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/layouting/distribution_layout.py +11 -15
  4. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/layouting/sparsity_layout.py +26 -31
  5. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/conversion.py +48 -64
  6. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/distribution.py +39 -57
  7. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/flow.py +24 -34
  8. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/matmul.py +21 -21
  9. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/misc/broadcast_ops.py +14 -19
  10. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/misc/row_wise.py +37 -55
  11. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/softmax.py +34 -46
  12. blksprs-2.0rc6/blksprs/utils/autotuning.py +78 -0
  13. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/utils/tools.py +6 -25
  14. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/utils/validation.py +3 -0
  15. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs.egg-info/PKG-INFO +14 -5
  16. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs.egg-info/SOURCES.txt +1 -0
  17. {blksprs-2.0rc3 → blksprs-2.0rc6}/pyproject.toml +1 -1
  18. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/__init__.py +0 -0
  19. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/partitioning.py +0 -0
  20. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/repeat.py +0 -0
  21. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/ops/transpose.py +0 -0
  22. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/utils/benchmarking.py +0 -0
  23. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/utils/blksprs_tensor.py +0 -0
  24. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs/utils/processing.py +0 -0
  25. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs.egg-info/dependency_links.txt +0 -0
  26. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs.egg-info/requires.txt +0 -0
  27. {blksprs-2.0rc3 → blksprs-2.0rc6}/blksprs.egg-info/top_level.txt +0 -0
  28. {blksprs-2.0rc3 → blksprs-2.0rc6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.0rc3
3
+ Version: 2.0rc6
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -24,10 +24,10 @@ Requires-Dist: matplotlib; extra == "test"
24
24
 
25
25
  ## Overview
26
26
 
27
- ### News
28
-
29
- 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
30
- LUTs, autocasting, and makes use of `torch.library.triton_op()`!
27
+ ### News
28
+
29
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
30
+ LUTs, autocasting, and makes use of `torch.library.triton_op()`!
31
31
 
32
32
  ---
33
33
 
@@ -106,6 +106,15 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
106
106
  It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
107
107
  library.
108
108
 
109
+ ## Known Limitations and Issues
110
+
111
+ - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
112
+ which could impact graph compilation.
113
+ - There seem to be some issues with autocasting, which require some operations to cast manually.
114
+ - There will be some slight numerical differences between vanilla and blksprs operations.
115
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
116
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
117
+
109
118
  ## Usage
110
119
 
111
120
  We provide an example below to demonstrate the usage of the library.
@@ -5,10 +5,10 @@
5
5
 
6
6
  ## Overview
7
7
 
8
- ### News
9
-
10
- 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
11
- LUTs, autocasting, and makes use of `torch.library.triton_op()`!
8
+ ### News
9
+
10
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
11
+ LUTs, autocasting, and makes use of `torch.library.triton_op()`!
12
12
 
13
13
  ---
14
14
 
@@ -87,6 +87,15 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
87
87
  It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
88
88
  library.
89
89
 
90
+ ## Known Limitations and Issues
91
+
92
+ - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
93
+ which could impact graph compilation.
94
+ - There seem to be some issues with autocasting, which require some operations to cast manually.
95
+ - There will be some slight numerical differences between vanilla and blksprs operations.
96
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
97
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
98
+
90
99
  ## Usage
91
100
 
92
101
  We provide an example below to demonstrate the usage of the library.
@@ -4,7 +4,8 @@ from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
6
  from blksprs.utils.blksprs_tensor import BlksprsTensor
7
- from blksprs.utils.tools import stride, get_autotune_configs
7
+ from blksprs.utils.tools import stride
8
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
8
9
  from blksprs.utils.validation import validate_dimensions, validate_device, \
9
10
  validate_contiguous
10
11
 
@@ -47,6 +48,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
47
48
  triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
48
49
  triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
49
50
 
51
+ # TODO wrap
50
52
  (build_distribution_layout_kernel[triton_grid]
51
53
  (indices,
52
54
  i_b, i_b_s, i_r_s, i_c_s,
@@ -62,7 +64,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
62
64
 
63
65
  @triton.autotune(
64
66
  configs=get_autotune_configs(),
65
- key=[],
67
+ key=["sparsity_block_size"],
68
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
66
69
  reset_to_zero=["o"]
67
70
  )
68
71
  @triton.jit
@@ -80,9 +83,6 @@ def build_distribution_layout_kernel(i,
80
83
  pid_row = tl.program_id(axis=1)
81
84
  pid_col = tl.program_id(axis=2)
82
85
 
83
- # Get valid triton block size
84
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
85
-
86
86
  # Get position of current sparsity block consisting of its batch, row, and column index
87
87
  spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
88
88
  spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
@@ -97,12 +97,10 @@ def build_distribution_layout_kernel(i,
97
97
  spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
98
98
 
99
99
  blk_i_idx = (pid_blk * i_b_s +
100
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
101
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
102
- blk_i_msk = ((blk_i_idx >= 0 and
103
- blk_i_idx < i_b * i_b_s) and
104
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
105
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
100
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
101
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
102
+ blk_i_msk = (blk_i_idx >= 0 and
103
+ blk_i_idx < i_b * i_b_s)
106
104
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
107
105
 
108
106
  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -120,8 +118,6 @@ def build_distribution_layout_kernel(i,
120
118
  blk_o_idx = ((dst_bat_idx * o_b_s) +
121
119
  (dst_row_idx * o_r_s) +
122
120
  (dst_col_idx * o_c_s))
123
- blk_o_msk = ((blk_o_idx >= 0 and
124
- blk_o_idx < o_b * o_b_s) and
125
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
126
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
121
+ blk_o_msk = (blk_o_idx >= 0 and
122
+ blk_o_idx < o_b * o_b_s)
127
123
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
@@ -7,7 +7,8 @@ from torch._library.triton import wrap_triton
7
7
  from triton import language as tl
8
8
 
9
9
  from blksprs.utils.blksprs_tensor import BlksprsTensor
10
- from blksprs.utils.tools import stride, get_autotune_configs
10
+ from blksprs.utils.tools import stride
11
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
11
12
  from blksprs.utils.validation import validate_dimensions, validate_device, \
12
13
  validate_contiguous, validate_sparsity, validate_sparsity_block_size
13
14
 
@@ -37,10 +38,11 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
37
38
  o_b_s, o_r_s, o_c_s = stride(output)
38
39
 
39
40
  triton_grid = lambda meta: [x_b,
40
- triton.cdiv(x_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
41
- triton.cdiv(x_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
41
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
42
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
42
43
 
43
- (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
44
+ # TODO: launch via wrap_triton() again (the kernel is currently called directly)
45
+ (build_sparsity_layout_kernel[triton_grid]
44
46
  (x,
45
47
  x_b, x_b_s, x_r_s, x_c_s,
46
48
  output,
@@ -52,7 +54,8 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
52
54
 
53
55
  @triton.autotune(
54
56
  configs=get_autotune_configs(),
55
- key=[],
57
+ key=["sparsity_block_size"],
58
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
56
59
  reset_to_zero=["o"]
57
60
  )
58
61
  @triton.jit
@@ -67,24 +70,19 @@ def build_sparsity_layout_kernel(x,
67
70
  pid_row = tl.program_id(axis=1)
68
71
  pid_col = tl.program_id(axis=2)
69
72
 
70
- # Get valid triton block size
71
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
72
-
73
73
  # Load x values
74
74
  blk_x_idx = (pid_bat * x_b_s +
75
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
76
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
77
- blk_x_msk = ((blk_x_idx >= 0 and
78
- blk_x_idx < x_b * x_b_s) and
79
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
80
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
75
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
76
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
77
+ blk_x_msk = (blk_x_idx >= 0 and
78
+ blk_x_idx < x_b * x_b_s)
81
79
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
82
80
 
83
81
  # Store sparsity layout value
84
82
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
85
83
  blk_o_idx = (pid_bat * o_b_s +
86
- (((pid_row * val_tbs) // sparsity_block_size) * o_r_s +
87
- ((pid_col * val_tbs) // sparsity_block_size) * o_c_s))
84
+ (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
85
+ ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
88
86
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
89
87
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
90
88
 
@@ -129,10 +127,11 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
129
127
  o_b_s, o_r_s, o_c_s = stride(output)
130
128
 
131
129
  triton_grid = lambda meta: [x_b,
132
- triton.cdiv(x_r, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"])),
133
- triton.cdiv(x_c, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"]))]
130
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
131
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
134
132
 
135
- (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
133
+ # TODO: launch via wrap_triton() again (the kernel is currently called directly)
134
+ (build_sparsity_layout_adaption_kernel[triton_grid]
136
135
  (x,
137
136
  x_b, x_b_s, x_r_s, x_c_s,
138
137
  sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -146,7 +145,8 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
146
145
 
147
146
  @triton.autotune(
148
147
  configs=get_autotune_configs(),
149
- key=[],
148
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
149
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
150
150
  reset_to_zero=["o"]
151
151
  )
152
152
  @triton.jit
@@ -163,9 +163,6 @@ def build_sparsity_layout_adaption_kernel(x,
163
163
  pid_row = tl.program_id(axis=1)
164
164
  pid_col = tl.program_id(axis=2)
165
165
 
166
- # Get valid triton block size
167
- val_tbs = min(sparsity_block_size_to, TRITON_BLOCK_SIZE)
168
-
169
166
  # Get sparsity index of current output block consisting of its batch, row, and column index
170
167
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
171
168
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -181,20 +178,18 @@ def build_sparsity_layout_adaption_kernel(x,
181
178
 
182
179
  # Load x values
183
180
  blk_x_idx = ((pid_blk * x_b_s) +
184
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
185
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
186
- blk_x_msk = ((blk_x_idx >= 0 and
187
- blk_x_idx < x_b * x_b_s) and
188
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < sparsity_block_size_from and
189
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < sparsity_block_size_from))
181
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
182
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
183
+ blk_x_msk = (blk_x_idx >= 0 and
184
+ blk_x_idx < x_b * x_b_s)
190
185
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
191
186
 
192
187
  # Store sparsity layout value
193
188
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
194
189
  blk_o_idx = ((spa_bat * o_b_s) +
195
- (((pid_row * val_tbs + spa_row * sparsity_block_size_from)
190
+ (((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size_from)
196
191
  // sparsity_block_size_to) * o_r_s) +
197
- (((pid_col * val_tbs + spa_col * sparsity_block_size_from)
192
+ (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
198
193
  // sparsity_block_size_to) * o_c_s))
199
194
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
200
195
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
@@ -6,7 +6,8 @@ from triton import language as tl
6
6
 
7
7
  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense
12
13
 
@@ -54,7 +55,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
54
55
  @triton_op("blksprs::to_sparse", mutates_args={})
55
56
  def to_sparse_forward(x: Tensor, _: Tensor,
56
57
  sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
57
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
58
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
58
59
  dtype=x.dtype, device=x.device)
59
60
 
60
61
  x_b, x_r, x_c = x.size()
@@ -86,7 +87,9 @@ def to_sparse_backward(ctx, grad_output):
86
87
 
87
88
  @triton.autotune(
88
89
  configs=get_autotune_configs(),
89
- key=[],
90
+ key=["sparsity_block_size"],
91
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
92
+ reset_to_zero=["o"]
90
93
  )
91
94
  @triton.jit
92
95
  def to_sparse_kernel(x,
@@ -101,9 +104,6 @@ def to_sparse_kernel(x,
101
104
  pid_row = tl.program_id(axis=1)
102
105
  pid_col = tl.program_id(axis=2)
103
106
 
104
- # Get valid triton block size
105
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
106
-
107
107
  # Get sparsity index of current output block consisting of its batch, row, and column index
108
108
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
109
109
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -119,24 +119,20 @@ def to_sparse_kernel(x,
119
119
 
120
120
  # Load block from dense tensor
121
121
  blk_d_idx = (spa_bat * x_b_s +
122
- ((pid_row * val_tbs + spa_row * sparsity_block_size +
122
+ ((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
123
123
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
124
- ((pid_col * val_tbs + spa_col * sparsity_block_size +
124
+ ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
125
125
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
126
- blk_d_msk = ((blk_d_idx >= 0 and
127
- blk_d_idx < x_b * x_b_s) and
128
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
129
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
126
+ blk_d_msk = (blk_d_idx >= 0 and
127
+ blk_d_idx < x_b * x_b_s)
130
128
  blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
131
129
 
132
130
  # Store block in sparse tensor
133
131
  blk_o_idx = ((pid_blk * o_b_s) +
134
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
135
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
136
- blk_o_msk = ((blk_o_idx >= 0 and
137
- blk_o_idx < (pid_blk + 1) * o_b_s) and
138
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
139
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
132
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
133
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
134
+ blk_o_msk = (blk_o_idx >= 0 and
135
+ blk_o_idx < (pid_blk + 1) * o_b_s)
140
136
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
141
137
 
142
138
 
@@ -227,8 +223,8 @@ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
227
223
  o_b_s, o_r_s, o_c_s = stride(output)
228
224
 
229
225
  triton_grid = lambda meta: [o_b,
230
- triton.cdiv(o_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
231
- triton.cdiv(o_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
226
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
227
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
232
228
 
233
229
  (wrap_triton(to_dense_kernel)[triton_grid]
234
230
  (x,
@@ -251,7 +247,9 @@ def to_dense_backward(ctx, grad_output):
251
247
 
252
248
  @triton.autotune(
253
249
  configs=get_autotune_configs(),
254
- key=[],
250
+ key=["sparsity_block_size"],
251
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
252
+ restore_value=["o"]
255
253
  )
256
254
  @triton.jit
257
255
  def to_dense_kernel(x,
@@ -267,12 +265,9 @@ def to_dense_kernel(x,
267
265
  pid_row = tl.program_id(axis=1)
268
266
  pid_col = tl.program_id(axis=2)
269
267
 
270
- # Get valid triton block size
271
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
272
-
273
268
  # Get sparsity index of current block
274
- spa_row = (pid_row * val_tbs) // sparsity_block_size
275
- spa_col = (pid_col * val_tbs) // sparsity_block_size
269
+ spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
270
+ spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
276
271
 
277
272
  # Get reverse sparsity index for current block
278
273
  rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
@@ -282,22 +277,18 @@ def to_dense_kernel(x,
282
277
  # If block is present commence operations
283
278
  if rev_idx_spa >= 0:
284
279
  blk_idx = (rev_idx_spa * x_b_s +
285
- (((pid_row % (sparsity_block_size // val_tbs)) * val_tbs +
280
+ (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
286
281
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
287
- (((pid_col % (sparsity_block_size // val_tbs)) * val_tbs +
282
+ (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
288
283
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
289
- blk_msk = ((blk_idx >= 0 and
290
- blk_idx < x_b * x_b_s) and
291
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
292
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
284
+ blk_msk = (blk_idx >= 0 and
285
+ blk_idx < x_b * x_b_s)
293
286
  blk = tl.load(x + blk_idx, mask=blk_msk)
294
287
 
295
288
  o_idx = (pid_blk * o_b_s +
296
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
297
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
298
- o_msk = ((o_idx >= 0 and o_idx < o_b * o_b_s) and
299
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
300
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
289
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
290
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
291
+ o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
301
292
  tl.store(o + o_idx, blk, o_msk)
302
293
 
303
294
 
@@ -401,12 +392,11 @@ def adapt_layout_forward(x: Tensor,
401
392
  s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
402
393
 
403
394
  triton_grid = lambda meta: [o_b,
404
- triton.cdiv(o_r, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
405
- meta["TRITON_BLOCK_SIZE"])),
406
- triton.cdiv(o_c, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
407
- meta["TRITON_BLOCK_SIZE"]))]
395
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
396
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
408
397
 
409
- (wrap_triton(adapt_layout_kernel)[triton_grid]
398
+ # TODO: launch via wrap_triton() again (the kernel is currently called directly)
399
+ (adapt_layout_kernel[triton_grid]
410
400
  (x,
411
401
  x_b, x_b_s, x_r_s, x_c_s,
412
402
  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -432,7 +422,8 @@ def adapt_layout_backward(ctx, grad_output):
432
422
 
433
423
  @triton.autotune(
434
424
  configs=get_autotune_configs(),
435
- key=[],
425
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
426
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
436
427
  reset_to_zero=["o"]
437
428
  )
438
429
  @triton.jit
@@ -451,9 +442,6 @@ def adapt_layout_kernel(x,
451
442
  pid_row = tl.program_id(axis=1)
452
443
  pid_col = tl.program_id(axis=2)
453
444
 
454
- # Get valid triton block size (Triton can only handle 2-valued min)
455
- val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)
456
-
457
445
  # Get position of current sparsity block consisting of its batch, row, and column index
458
446
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
459
447
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -469,8 +457,8 @@ def adapt_layout_kernel(x,
469
457
 
470
458
  # Get equivalent sparsity block in from layout
471
459
  spa_bat_x = spa_bat_o
472
- spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * val_tbs) // sparsity_block_size_from
473
- spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * val_tbs) // sparsity_block_size_from
460
+ spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
461
+ spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
474
462
 
475
463
  # Get reverse sparsity indices for x
476
464
  rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
@@ -482,29 +470,25 @@ def adapt_layout_kernel(x,
482
470
  # If block is present commence operations
483
471
  if rev_idx_spa_x >= 0:
484
472
  # Calculate triton block size shifts
485
- shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * val_tbs)
486
- % sparsity_block_size_from) // val_tbs
487
- shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * val_tbs)
488
- % sparsity_block_size_from) // val_tbs
473
+ shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
474
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
475
+ shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
476
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
489
477
 
490
478
  # Load x values
491
479
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
492
- ((shift_row_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
493
- ((shift_col_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
494
- blk_x_msk = ((blk_x_idx >= 0 and
495
- blk_x_idx < x_b * x_b_s) and
496
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
497
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
480
+ ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
481
+ ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
482
+ blk_x_msk = (blk_x_idx >= 0 and
483
+ blk_x_idx < x_b * x_b_s)
498
484
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
499
485
 
500
486
  # Store output
501
487
  blk_o_idx = ((pid_blk * o_b_s) +
502
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
503
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
504
- blk_o_msk = ((blk_o_idx >= 0 and
505
- blk_o_idx < o_b * o_b_s) and
506
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
507
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
488
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
489
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
490
+ blk_o_msk = (blk_o_idx >= 0 and
491
+ blk_o_idx < o_b * o_b_s)
508
492
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
509
493
 
510
494