blksprs 2.0rc4__tar.gz → 2.0rc7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {blksprs-2.0rc4 → blksprs-2.0rc7}/PKG-INFO +18 -5
  2. {blksprs-2.0rc4 → blksprs-2.0rc7}/README.md +17 -4
  3. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/layouting/distribution_layout.py +11 -15
  4. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/layouting/sparsity_layout.py +26 -31
  5. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/conversion.py +45 -63
  6. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/distribution.py +38 -57
  7. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/flow.py +22 -33
  8. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/matmul.py +19 -20
  9. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/misc/broadcast_ops.py +15 -19
  10. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/misc/row_wise.py +39 -54
  11. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/softmax.py +30 -44
  12. blksprs-2.0rc7/blksprs/utils/autotuning.py +78 -0
  13. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/utils/tools.py +0 -28
  14. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/utils/validation.py +3 -0
  15. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs.egg-info/PKG-INFO +18 -5
  16. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs.egg-info/SOURCES.txt +1 -0
  17. {blksprs-2.0rc4 → blksprs-2.0rc7}/pyproject.toml +1 -1
  18. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/__init__.py +0 -0
  19. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/partitioning.py +0 -0
  20. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/repeat.py +0 -0
  21. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/ops/transpose.py +0 -0
  22. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/utils/benchmarking.py +0 -0
  23. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/utils/blksprs_tensor.py +0 -0
  24. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs/utils/processing.py +0 -0
  25. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs.egg-info/dependency_links.txt +0 -0
  26. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs.egg-info/requires.txt +0 -0
  27. {blksprs-2.0rc4 → blksprs-2.0rc7}/blksprs.egg-info/top_level.txt +0 -0
  28. {blksprs-2.0rc4 → blksprs-2.0rc7}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: blksprs
- Version: 2.0rc4
+ Version: 2.0rc7
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -24,10 +24,10 @@ Requires-Dist: matplotlib; extra == "test"

  ## Overview

- ### News
-
- 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
- LUTs, autocasting, and makes use of `torch.library.triton_op()`!
+ ### News
+
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+ LUTs, autocasting, and makes use of `torch.library.triton_op()`!

  ---

@@ -106,6 +106,19 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
  It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
  library.

+ ## Known Limitations and Issues
+
+ - Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
+ In order to work around this bug, a manual conversion of some values is needed, (slightly) negatively impacting
+ performance.
+ Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+ - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+ which could impact graph compilation.
+ - There seem to be some issues with autocasting, forcing some operations to cast manually.
+ - There will be some slight numerical differences between vanilla and blksprs operations.
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
+
  ## Usage

  We provide an example below to demonstrate the usage of the library.
@@ -5,10 +5,10 @@

  ## Overview

- ### News
-
- 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
- LUTs, autocasting, and makes use of `torch.library.triton_op()`!
+ ### News
+
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+ LUTs, autocasting, and makes use of `torch.library.triton_op()`!

  ---

@@ -87,6 +87,19 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
  It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
  library.

+ ## Known Limitations and Issues
+
+ - Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
+ In order to work around this bug, a manual conversion of some values is needed, (slightly) negatively impacting
+ performance.
+ Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+ - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+ which could impact graph compilation.
+ - There seem to be some issues with autocasting, forcing some operations to cast manually.
+ - There will be some slight numerical differences between vanilla and blksprs operations.
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
+
  ## Usage

  We provide an example below to demonstrate the usage of the library.
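The recurring change across the kernel diffs below is the switch from `key=[]` to keyed autotuning plus `prune_configs_by={"early_config_prune": prune_autotune_configs}`, with the new `blksprs/utils/autotuning.py` supplying the prune hooks. A minimal, hypothetical sketch of such a hook is shown here; the actual `prune_autotune_configs()` may differ, and `Config` below is a stand-in for `triton.Config` so the sketch runs without Triton installed.

```python
class Config:
    """Minimal stand-in for triton.Config (kwargs holds meta-parameters)."""
    def __init__(self, kwargs):
        self.kwargs = kwargs


def prune_autotune_configs(configs, named_args, **kwargs):
    # Hypothetical prune hook: keep only configs whose TRITON_BLOCK_SIZE does
    # not exceed the sparsity block size, which would explain why the kernels
    # in this diff can drop the min(sparsity_block_size, TRITON_BLOCK_SIZE)
    # clamp and the per-element mask terms.
    sparsity_block_size = named_args["sparsity_block_size"]
    return [c for c in configs
            if c.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size]


configs = [Config({"TRITON_BLOCK_SIZE": s}) for s in (16, 32, 64, 128)]
pruned = prune_autotune_configs(configs, {"sparsity_block_size": 64})
print([c.kwargs["TRITON_BLOCK_SIZE"] for c in pruned])  # → [16, 32, 64]
```

Because the hook runs before benchmarking, invalid or wasteful configs are never compiled; this is also why `wrap_triton()`, which does not support pruning yet, cannot wrap these kernels.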
@@ -4,7 +4,8 @@ from torch import Tensor
  from triton import language as tl

  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import stride, get_autotune_configs
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.validation import validate_dimensions, validate_device, \
  validate_contiguous

@@ -47,6 +48,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
  triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
  triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]

+ # TODO wrap
  (build_distribution_layout_kernel[triton_grid]
  (indices,
  i_b, i_b_s, i_r_s, i_c_s,
@@ -62,7 +64,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T

  @triton.autotune(
  configs=get_autotune_configs(),
- key=[],
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
  reset_to_zero=["o"]
  )
  @triton.jit
@@ -80,9 +83,6 @@ def build_distribution_layout_kernel(i,
  pid_row = tl.program_id(axis=1)
  pid_col = tl.program_id(axis=2)

- # Get valid triton block size
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
  spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
@@ -97,12 +97,10 @@ def build_distribution_layout_kernel(i,
  spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)

  blk_i_idx = (pid_blk * i_b_s +
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
- blk_i_msk = ((blk_i_idx >= 0 and
- blk_i_idx < i_b * i_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+ blk_i_msk = (blk_i_idx >= 0 and
+ blk_i_idx < i_b * i_b_s)
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -120,8 +118,6 @@ def build_distribution_layout_kernel(i,
  blk_o_idx = ((dst_bat_idx * o_b_s) +
  (dst_row_idx * o_r_s) +
  (dst_col_idx * o_c_s))
- blk_o_msk = ((blk_o_idx >= 0 and
- blk_o_idx < o_b * o_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
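The kernel bodies above compute flat 2-D block offsets by broadcasting a row-offset column vector against a column-offset row vector, then mask out-of-bounds elements. A small NumPy sketch of that arithmetic (illustrative values only; the strides and sizes here are assumptions, not library defaults):

```python
import numpy as np

TRITON_BLOCK_SIZE = 4
i_r_s, i_c_s = 8, 1      # row/column strides of a contiguous 8x8 matrix
i_b, i_b_s = 1, 64       # one batch of 64 elements
pid_row, pid_col = 1, 0  # second block row, first block column

rng = np.arange(TRITON_BLOCK_SIZE)
# Same pattern as the kernel: [:, None] broadcast against [None, :]
blk_i_idx = (((pid_row * TRITON_BLOCK_SIZE + rng) * i_r_s)[:, None]
             + ((pid_col * TRITON_BLOCK_SIZE + rng) * i_c_s)[None, :])
# After config pruning the per-element arange < val_tbs terms are gone;
# only the flat bounds check remains.
blk_i_msk = (blk_i_idx >= 0) & (blk_i_idx < i_b * i_b_s)

print(blk_i_idx[0])        # → [32 33 34 35]
print(bool(blk_i_msk.all()))  # → True
```

This is why the diff can delete the `val_tbs` clamp wholesale: once `TRITON_BLOCK_SIZE <= sparsity_block_size` is guaranteed by pruning, every lane of the tile is valid and the mask reduces to a bounds check.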
@@ -7,7 +7,8 @@ from torch._library.triton import wrap_triton
  from triton import language as tl

  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import stride, get_autotune_configs
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
  from blksprs.utils.validation import validate_dimensions, validate_device, \
  validate_contiguous, validate_sparsity, validate_sparsity_block_size

@@ -37,10 +38,11 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
  o_b_s, o_r_s, o_c_s = stride(output)

  triton_grid = lambda meta: [x_b,
- triton.cdiv(x_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
- triton.cdiv(x_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

- (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
+ # TODO wrap
+ (build_sparsity_layout_kernel[triton_grid]
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
  output,
@@ -52,7 +54,8 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:

  @triton.autotune(
  configs=get_autotune_configs(),
- key=[],
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
  reset_to_zero=["o"]
  )
  @triton.jit
@@ -67,24 +70,19 @@ def build_sparsity_layout_kernel(x,
  pid_row = tl.program_id(axis=1)
  pid_col = tl.program_id(axis=2)

- # Get valid triton block size
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
  # Load x values
  blk_x_idx = (pid_bat * x_b_s +
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = ((blk_x_idx >= 0 and
- blk_x_idx < x_b * x_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx >= 0 and
+ blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store sparsity layout value
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
  blk_o_idx = (pid_bat * o_b_s +
- (((pid_row * val_tbs) // sparsity_block_size) * o_r_s +
- ((pid_col * val_tbs) // sparsity_block_size) * o_c_s))
+ (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
+ ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)

@@ -129,10 +127,11 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
  o_b_s, o_r_s, o_c_s = stride(output)

  triton_grid = lambda meta: [x_b,
- triton.cdiv(x_r, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"])),
- triton.cdiv(x_c, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"]))]
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

- (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
+ # TODO wrap
+ (build_sparsity_layout_adaption_kernel[triton_grid]
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
  sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -146,7 +145,8 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso

  @triton.autotune(
  configs=get_autotune_configs(),
- key=[],
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
  reset_to_zero=["o"]
  )
  @triton.jit
@@ -163,9 +163,6 @@ def build_sparsity_layout_adaption_kernel(x,
  pid_row = tl.program_id(axis=1)
  pid_col = tl.program_id(axis=2)

- # Get valid triton block size
- val_tbs = min(sparsity_block_size_to, TRITON_BLOCK_SIZE)
-
  # Get sparsity index of current output block consisting of its batch, row, and column index
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -181,20 +178,18 @@ def build_sparsity_layout_adaption_kernel(x,

  # Load x values
  blk_x_idx = ((pid_blk * x_b_s) +
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = ((blk_x_idx >= 0 and
- blk_x_idx < x_b * x_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < sparsity_block_size_from and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < sparsity_block_size_from))
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx >= 0 and
+ blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store sparsity layout value
  if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
  blk_o_idx = ((spa_bat * o_b_s) +
- (((pid_row * val_tbs + spa_row * sparsity_block_size_from)
+ (((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size_from)
  // sparsity_block_size_to) * o_r_s) +
- (((pid_col * val_tbs + spa_col * sparsity_block_size_from)
+ (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
  // sparsity_block_size_to) * o_c_s))
  blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
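The `build_sparsity_layout_kernel` above marks a layout entry whenever `tl.min(blk_x) != 0 or tl.max(blk_x) != 0`, i.e. whenever the block contains any non-zero element. A dense NumPy reference of that semantics (a sketch for clarity, not the library API):

```python
import numpy as np

def build_sparsity_layout_dense(x: np.ndarray, sparsity_block_size: int) -> np.ndarray:
    """Dense reference: 1 per sparsity block of x that holds any non-zero."""
    b, r, c = x.shape
    # Split rows and columns into (block index, offset within block).
    blocks = x.reshape(b, r // sparsity_block_size, sparsity_block_size,
                       c // sparsity_block_size, sparsity_block_size)
    # A block is "present" if any of its entries is non-zero.
    return (blocks != 0).any(axis=(2, 4)).astype(np.int64)

x = np.zeros((1, 4, 4))
x[0, 0, 1] = 3.0  # only the top-left 2x2 block is non-zero
print(build_sparsity_layout_dense(x, 2))  # → [[[1 0] [0 0]]]
```

The kernel computes the same result one `TRITON_BLOCK_SIZE` tile at a time, with `reset_to_zero=["o"]` ensuring the layout starts cleared on every autotuning run.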
@@ -6,7 +6,8 @@ from triton import language as tl

  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import stride, get_autotune_configs
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
  validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense

@@ -86,7 +87,8 @@ def to_sparse_backward(ctx, grad_output):

  @triton.autotune(
  configs=get_autotune_configs(),
- key=[],
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
  reset_to_zero=["o"]
  )
  @triton.jit
@@ -102,9 +104,6 @@ def to_sparse_kernel(x,
  pid_row = tl.program_id(axis=1)
  pid_col = tl.program_id(axis=2)

- # Get valid triton block size
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
  # Get sparsity index of current output block consisting of its batch, row, and column index
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -120,24 +119,20 @@ def to_sparse_kernel(x,

  # Load block from dense tensor
  blk_d_idx = (spa_bat * x_b_s +
- ((pid_row * val_tbs + spa_row * sparsity_block_size +
+ ((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((pid_col * val_tbs + spa_col * sparsity_block_size +
+ ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_d_msk = ((blk_d_idx >= 0 and
- blk_d_idx < x_b * x_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ blk_d_msk = (blk_d_idx >= 0 and
+ blk_d_idx < x_b * x_b_s)
  blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)

  # Store block in sparse tensor
  blk_o_idx = ((pid_blk * o_b_s) +
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
- blk_o_msk = ((blk_o_idx >= 0 and
- blk_o_idx < (pid_blk + 1) * o_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < (pid_blk + 1) * o_b_s)
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)


@@ -228,8 +223,8 @@ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
  o_b_s, o_r_s, o_c_s = stride(output)

  triton_grid = lambda meta: [o_b,
- triton.cdiv(o_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
- triton.cdiv(o_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

  (wrap_triton(to_dense_kernel)[triton_grid]
  (x,
@@ -252,7 +247,8 @@ def to_dense_backward(ctx, grad_output):

  @triton.autotune(
  configs=get_autotune_configs(),
- key=[],
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
  restore_value=["o"]
  )
  @triton.jit
@@ -269,12 +265,9 @@ def to_dense_kernel(x,
  pid_row = tl.program_id(axis=1)
  pid_col = tl.program_id(axis=2)

- # Get valid triton block size
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
  # Get sparsity index of current block
- spa_row = (pid_row * val_tbs) // sparsity_block_size
- spa_col = (pid_col * val_tbs) // sparsity_block_size
+ spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
+ spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size

  # Get reverse sparsity index for current block
  rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
@@ -284,22 +277,18 @@ def to_dense_kernel(x,
  # If block is present commence operations
  if rev_idx_spa >= 0:
  blk_idx = (rev_idx_spa * x_b_s +
- (((pid_row % (sparsity_block_size // val_tbs)) * val_tbs +
+ (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- (((pid_col % (sparsity_block_size // val_tbs)) * val_tbs +
+ (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_msk = ((blk_idx >= 0 and
- blk_idx < x_b * x_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ blk_msk = (blk_idx >= 0 and
+ blk_idx < x_b * x_b_s)
  blk = tl.load(x + blk_idx, mask=blk_msk)

  o_idx = (pid_blk * o_b_s +
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- o_msk = ((o_idx >= 0 and o_idx < o_b * o_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
  tl.store(o + o_idx, blk, o_msk)


@@ -403,12 +392,11 @@ def adapt_layout_forward(x: Tensor,
  s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)

  triton_grid = lambda meta: [o_b,
- triton.cdiv(o_r, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
- meta["TRITON_BLOCK_SIZE"])),
- triton.cdiv(o_c, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
- meta["TRITON_BLOCK_SIZE"]))]
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

- (wrap_triton(adapt_layout_kernel)[triton_grid]
+ # TODO wrap
+ (adapt_layout_kernel[triton_grid]
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -434,7 +422,8 @@ def adapt_layout_backward(ctx, grad_output):

  @triton.autotune(
  configs=get_autotune_configs(),
- key=[],
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
  reset_to_zero=["o"]
  )
  @triton.jit
@@ -453,9 +442,6 @@ def adapt_layout_kernel(x,
  pid_row = tl.program_id(axis=1)
  pid_col = tl.program_id(axis=2)

- # Get valid triton block size (Triton can only handle 2-valued min)
- val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)
-
  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -471,8 +457,8 @@ def adapt_layout_kernel(x,

  # Get equivalent sparsity block in from layout
  spa_bat_x = spa_bat_o
- spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * val_tbs) // sparsity_block_size_from
- spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * val_tbs) // sparsity_block_size_from
+ spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+ spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from

  # Get reverse sparsity indices for x
  rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
@@ -484,29 +470,25 @@ def adapt_layout_kernel(x,
  # If block is present commence operations
  if rev_idx_spa_x >= 0:
  # Calculate triton block size shifts
- shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * val_tbs)
- % sparsity_block_size_from) // val_tbs
- shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * val_tbs)
- % sparsity_block_size_from) // val_tbs
+ shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+ shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE

  # Load x values
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
- ((shift_row_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((shift_col_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = ((blk_x_idx >= 0 and
- blk_x_idx < x_b * x_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx >= 0 and
+ blk_x_idx < x_b * x_b_s)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store output
  blk_o_idx = ((pid_blk * o_b_s) +
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = ((blk_o_idx >= 0 and
- blk_o_idx < o_b * o_b_s) and
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
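The `to_dense_kernel` changes above keep the same underlying semantics: each stored sparse tile is scattered back to its dense position via the reverse sparsity index. A naive dense reference of that to-dense semantics (a sketch assuming tiles are stored in layout scan order; this is not the blksprs API):

```python
import numpy as np

def to_dense_ref(tiles: np.ndarray, sparsity_layout: np.ndarray,
                 sparsity_block_size: int) -> np.ndarray:
    """Scatter [n_blocks, bs, bs] tiles into a dense tensor per the layout."""
    b, rows, cols = sparsity_layout.shape
    bs = sparsity_block_size
    out = np.zeros((b, rows * bs, cols * bs), dtype=tiles.dtype)
    blk = 0
    for bat in range(b):
        for r in range(rows):
            for c in range(cols):
                if sparsity_layout[bat, r, c]:
                    out[bat, r * bs:(r + 1) * bs, c * bs:(c + 1) * bs] = tiles[blk]
                    blk += 1
    return out

layout = np.array([[[1, 0], [0, 1]]])
tiles = np.stack([np.full((2, 2), 1.0), np.full((2, 2), 2.0)])
dense = to_dense_ref(tiles, layout, 2)
print(dense[0])
```

The kernels replace the Python loops with one program per output tile; absent blocks (`rev_idx_spa < 0`) are simply skipped, leaving zeros, which matches the `restore_value=["o"]` handling during autotuning.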