blksprs 2.0rc4__py3-none-any.whl → 2.0rc6__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package. It is provided for informational purposes only and reflects the package contents exactly as published to the supported public registries.
blksprs/ops/flow.py CHANGED
@@ -5,7 +5,8 @@ from torch._library import triton_op
5
5
  from torch._library.triton import wrap_triton
6
6
  from triton import language as tl
7
7
 
8
- from blksprs.utils.tools import stride, get_autotune_configs
8
+ from blksprs.utils.tools import stride
9
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
9
10
 
10
11
 
11
12
  @triton_op("blksprs::flow_pull", mutates_args={})
@@ -43,7 +44,8 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
43
44
 
44
45
  @triton.autotune(
45
46
  configs=get_autotune_configs(),
46
- key=[],
47
+ key=["sparsity_block_size"],
48
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
47
49
  reset_to_zero=["o"]
48
50
  )
49
51
  @triton.jit
@@ -61,9 +63,6 @@ def flow_pull_kernel(x,
61
63
  pid_row = tl.program_id(axis=1)
62
64
  pid_col = tl.program_id(axis=2)
63
65
 
64
- # Get valid triton block size
65
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
66
-
67
66
  # Get sparsity index of current output block consisting of its batch, row, and column index
68
67
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
69
68
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -86,21 +85,17 @@ def flow_pull_kernel(x,
86
85
 
87
86
  if rev_idx_spa >= 0:
88
87
  blk_x_idx = (rev_idx_spa * x_b_s +
89
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
90
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
91
- blk_x_msk = ((blk_x_idx >= 0 and
92
- blk_x_idx < x_b * x_b_s) and
93
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
94
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
88
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
89
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
90
+ blk_x_msk = (blk_x_idx >= 0 and
91
+ blk_x_idx < x_b * x_b_s)
95
92
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
96
93
 
97
94
  blk_o_idx = (pid_blk * o_b_s +
98
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
99
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
100
- blk_o_msk = ((blk_o_idx >= 0 and
101
- blk_o_idx < o_b * o_b_s) and
102
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
103
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
95
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
96
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
97
+ blk_o_msk = (blk_o_idx >= 0 and
98
+ blk_o_idx < o_b * o_b_s)
104
99
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
105
100
 
106
101
 
@@ -138,7 +133,8 @@ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor
138
133
 
139
134
  @triton.autotune(
140
135
  configs=get_autotune_configs(),
141
- key=[],
136
+ key=["sparsity_block_size"],
137
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
142
138
  reset_to_zero=["o"]
143
139
  )
144
140
  @triton.jit
@@ -156,9 +152,6 @@ def flow_push_kernel(x,
156
152
  pid_row = tl.program_id(axis=1)
157
153
  pid_col = tl.program_id(axis=2)
158
154
 
159
- # Get valid triton block size
160
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
161
-
162
155
  # Get sparsity index of current input block consisting of its batch, row, and column index
163
156
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
164
157
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -181,19 +174,15 @@ def flow_push_kernel(x,
181
174
 
182
175
  if rev_idx_spa >= 0:
183
176
  blk_x_idx = (pid_blk * x_b_s +
184
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
185
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
186
- blk_x_msk = ((blk_x_idx >= 0 and
187
- blk_x_idx < x_b * x_b_s) and
188
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
189
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
177
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
178
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
179
+ blk_x_msk = (blk_x_idx >= 0 and
180
+ blk_x_idx < x_b * x_b_s)
190
181
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
191
182
 
192
183
  blk_o_idx = (rev_idx_spa * o_b_s +
193
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
194
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
195
- blk_o_msk = ((blk_o_idx >= 0 and
196
- blk_o_idx < o_b * o_b_s) and
197
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
198
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
184
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
185
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
186
+ blk_o_msk = (blk_o_idx >= 0 and
187
+ blk_o_idx < o_b * o_b_s)
199
188
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
blksprs/ops/matmul.py CHANGED
@@ -6,7 +6,8 @@ from triton import language as tl
6
6
 
7
7
  from blksprs.ops.transpose import transpose
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float
12
13
 
@@ -117,7 +118,8 @@ def matmul_backward(ctx, grad_output):
117
118
 
118
119
  @triton.autotune(
119
120
  configs=get_autotune_configs(),
120
- key=[],
121
+ key=["sparsity_block_size"],
122
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
121
123
  reset_to_zero=["o"]
122
124
  )
123
125
  @triton.jit
@@ -141,9 +143,6 @@ def matmul_kernel(x,
141
143
  pid_row = tl.program_id(axis=1)
142
144
  pid_col = tl.program_id(axis=2)
143
145
 
144
- # Get valid triton block size
145
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
146
-
147
146
  # Get position of current sparsity block consisting of its batch, row, and column index
148
147
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
149
148
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -161,11 +160,11 @@ def matmul_kernel(x,
161
160
  buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
162
161
 
163
162
  # Slide over triton block sized segments of input tensors
164
- for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, val_tbs)):
163
+ for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
165
164
  # Convert to segment index of sparsity layout
166
- i_seg_spa = (i_seg_tri * val_tbs) // sparsity_block_size
165
+ i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
167
166
  # Calculate the triton segment index within a block
168
- i_seg_tri_mod = i_seg_tri % (sparsity_block_size // val_tbs)
167
+ i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
169
168
 
170
169
  # Get reverse sparsity indices for input tensors x and y
171
170
  # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
@@ -185,23 +184,23 @@ def matmul_kernel(x,
185
184
  # If both blocks are present commence calculation
186
185
  if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
187
186
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
188
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
189
- ((i_seg_tri_mod * val_tbs +
187
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
188
+ ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
190
189
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
191
190
  blk_x_msk = ((blk_x_idx >= 0 and
192
191
  blk_x_idx < x_b * x_b_s) and
193
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
194
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
192
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
193
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
195
194
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
196
195
 
197
196
  blk_y_idx = ((rev_idx_spa_y * y_b_s) +
198
- ((i_seg_tri_mod * val_tbs +
197
+ ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
199
198
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
200
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
199
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
201
200
  blk_y_msk = ((blk_y_idx >= 0 and
202
201
  blk_y_idx < y_b * y_b_s) and
203
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
204
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
202
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
203
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
205
204
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
206
205
 
207
206
  # Perform matrix multiplication
@@ -212,12 +211,12 @@ def matmul_kernel(x,
212
211
 
213
212
  # Store output
214
213
  blk_o_idx = ((pid_blk * o_b_s) +
215
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
216
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
214
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
215
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
217
216
  blk_o_msk = ((blk_o_idx >= 0 and
218
217
  blk_o_idx < o_b * o_b_s) and
219
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
220
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
218
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
219
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
221
220
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
222
221
 
223
222
 
@@ -6,7 +6,8 @@ from torch._library.triton import wrap_triton
6
6
  from triton import language as tl
7
7
 
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
11
  from blksprs.utils.validation import validate_contiguous, validate_device, \
11
12
  validate_sparsity_block_size
12
13
 
@@ -87,7 +88,8 @@ def broadcast_add_forward(x: Tensor, y: Tensor,
87
88
 
88
89
  @triton.autotune(
89
90
  configs=get_autotune_configs(),
90
- key=[],
91
+ key=["sparsity_block_size"],
92
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
91
93
  reset_to_zero=["o"]
92
94
  )
93
95
  @triton.jit
@@ -105,9 +107,6 @@ def broadcast_add_kernel(x,
105
107
  pid_row = tl.program_id(axis=1)
106
108
  pid_col = tl.program_id(axis=2)
107
109
 
108
- # Get valid triton block size
109
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
110
-
111
110
  # Get position of current sparsity block consisting of its batch, row, and column index
112
111
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
113
112
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -123,20 +122,18 @@ def broadcast_add_kernel(x,
123
122
 
124
123
  # Load x block
125
124
  blk_x_idx = (spa_bat_o * x_b_s +
126
- ((pid_row * val_tbs + spa_row_o * sparsity_block_size +
125
+ ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
127
126
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
128
- blk_x_msk = ((blk_x_idx >= 0 and
129
- blk_x_idx < x_b * x_b_s) and
130
- (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
127
+ blk_x_msk = (blk_x_idx >= 0 and
128
+ blk_x_idx < x_b * x_b_s)
131
129
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
132
130
 
133
131
  # Load y block
134
132
  blk_y_idx = (spa_bat_o * y_b_s +
135
- ((pid_col * val_tbs + spa_col_o * sparsity_block_size +
133
+ ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
136
134
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
137
- blk_y_msk = ((blk_y_idx >= 0 and
138
- blk_y_idx < y_b * y_b_s) and
139
- (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
135
+ blk_y_msk = (blk_y_idx >= 0 and
136
+ blk_y_idx < y_b * y_b_s)
140
137
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
141
138
 
142
139
  # Compute sum
@@ -145,10 +142,8 @@ def broadcast_add_kernel(x,
145
142
 
146
143
  # Store result
147
144
  blk_o_idx = ((pid_blk * o_b_s) +
148
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
149
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
150
- blk_o_msk = ((blk_o_idx >= 0 and
151
- blk_o_idx < o_b * o_b_s) and
152
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
153
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
145
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
146
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
147
+ blk_o_msk = (blk_o_idx >= 0 and
148
+ blk_o_idx < o_b * o_b_s)
154
149
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
@@ -5,7 +5,8 @@ from torch._library.triton import wrap_triton, triton_op
5
5
  from triton import language as tl
6
6
 
7
7
  from blksprs.utils.blksprs_tensor import BlksprsTensor
8
- from blksprs.utils.tools import stride, get_autotune_configs
8
+ from blksprs.utils.tools import stride, get_autocast_min_val
9
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
9
10
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
10
11
  validate_sparsity_block_size
11
12
 
@@ -96,7 +97,8 @@ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
96
97
 
97
98
  @triton.autotune(
98
99
  configs=get_autotune_configs(),
99
- key=[],
100
+ key=["sparsity_block_size"],
101
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
100
102
  reset_to_zero=["o"]
101
103
  )
102
104
  @triton.jit
@@ -114,9 +116,6 @@ def row_wise_sum_kernel(x,
114
116
  pid_row = tl.program_id(axis=1)
115
117
  pid_col = tl.program_id(axis=2)
116
118
 
117
- # Get valid triton block size
118
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
119
-
120
119
  # Get position of current sparsity block consisting of its batch and row index
121
120
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
122
121
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -137,23 +136,19 @@ def row_wise_sum_kernel(x,
137
136
  return
138
137
 
139
138
  blk_idx = ((pid_blk * x_b_s) +
140
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
141
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
142
- blk_msk = ((blk_idx >= 0 and
143
- blk_idx < x_b * x_b_s) and
144
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
145
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
139
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
140
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
141
+ blk_msk = (blk_idx >= 0 and
142
+ blk_idx < x_b * x_b_s)
146
143
  blk = tl.load(x + blk_idx, mask=blk_msk)
147
144
 
148
145
  buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
149
146
 
150
147
  o_idx = (rev_idx_spa * o_b_s +
151
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
148
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
152
149
  (tl.arange(0, 1))[None, :])
153
- o_msk = ((o_idx >= 0 and
154
- o_idx < o_b * o_b_s) and
155
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
156
- tl.arange(0, 1)[None, :] < val_tbs))
150
+ o_msk = (o_idx >= 0 and
151
+ o_idx < o_b * o_b_s)
157
152
  tl.atomic_add(o + o_idx, buf, o_msk)
158
153
 
159
154
 
@@ -214,7 +209,7 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
214
209
  output = torch.full(size=(n_sparse_blocks_output,
215
210
  sparsity_block_size,
216
211
  1 if flag_slice_only else sparsity_block_size),
217
- fill_value=float("-inf"),
212
+ fill_value=get_autocast_min_val(),
218
213
  device=x.device)
219
214
 
220
215
  x_b, x_r, x_c = x.size()
@@ -245,7 +240,8 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
245
240
 
246
241
  @triton.autotune(
247
242
  configs=get_autotune_configs(),
248
- key=[],
243
+ key=["sparsity_block_size"],
244
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
249
245
  restore_value=["o"]
250
246
  )
251
247
  @triton.jit
@@ -263,9 +259,6 @@ def row_wise_max_kernel(x,
263
259
  pid_row = tl.program_id(axis=1)
264
260
  pid_col = tl.program_id(axis=2)
265
261
 
266
- # Get valid triton block size
267
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
268
-
269
262
  # Get position of current sparsity block consisting of its batch and row index
270
263
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
271
264
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -286,23 +279,19 @@ def row_wise_max_kernel(x,
286
279
  return
287
280
 
288
281
  blk_idx = ((pid_blk * x_b_s) +
289
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
290
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
291
- blk_msk = ((blk_idx >= 0 and
292
- blk_idx < x_b * x_b_s) and
293
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
294
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
282
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
283
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
284
+ blk_msk = (blk_idx >= 0 and
285
+ blk_idx < x_b * x_b_s)
295
286
  blk = tl.load(x + blk_idx, mask=blk_msk)
296
287
 
297
288
  buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
298
289
 
299
290
  o_idx = (rev_idx_spa * o_b_s +
300
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
291
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
301
292
  (tl.arange(0, 1))[None, :])
302
- o_msk = ((o_idx >= 0 and
303
- o_idx < o_b * o_b_s) and
304
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
305
- tl.arange(0, 1)[None, :] < val_tbs))
293
+ o_msk = (o_idx >= 0 and
294
+ o_idx < o_b * o_b_s)
306
295
  tl.atomic_max(o + o_idx, buf, o_msk)
307
296
 
308
297
 
@@ -371,7 +360,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
371
360
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
372
361
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
373
362
 
374
- (kernel_blocksparse_row_wise_add[triton_grid]
363
+ (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
375
364
  (x,
376
365
  x_b, x_b_s, x_r_s, x_c_s,
377
366
  sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -387,7 +376,8 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
387
376
 
388
377
  @triton.autotune(
389
378
  configs=get_autotune_configs(),
390
- key=[],
379
+ key=["sparsity_block_size"],
380
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
391
381
  reset_to_zero=["o"]
392
382
  )
393
383
  @triton.jit
@@ -406,9 +396,6 @@ def kernel_blocksparse_row_wise_add(x,
406
396
  pid_row = tl.program_id(axis=1)
407
397
  pid_col = tl.program_id(axis=2)
408
398
 
409
- # Get valid triton block size
410
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
411
-
412
399
  # Get position of current sparsity block consisting of its batch and row index
413
400
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
414
401
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -430,22 +417,18 @@ def kernel_blocksparse_row_wise_add(x,
430
417
 
431
418
  # Load x block
432
419
  blk_x_idx = ((pid_blk * x_b_s) +
433
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
434
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
435
- blk_x_msk = ((blk_x_idx >= 0 and
436
- blk_x_idx < x_b * x_b_s) and
437
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
438
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
420
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
421
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
422
+ blk_x_msk = (blk_x_idx >= 0 and
423
+ blk_x_idx < x_b * x_b_s)
439
424
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
440
425
 
441
426
  # Load sum block
442
427
  blk_s_idx = (rev_idx_spa_s * y_b_s +
443
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
428
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
444
429
  (tl.arange(0, 1) * y_c_s)[None, :])
445
- blk_s_msk = ((blk_s_idx >= 0 and
446
- blk_s_idx < y_b * y_b_s) and
447
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
448
- tl.arange(0, 1)[None, :] < val_tbs))
430
+ blk_s_msk = (blk_s_idx >= 0 and
431
+ blk_s_idx < y_b * y_b_s)
449
432
  blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
450
433
 
451
434
  # Compute exp
@@ -453,10 +436,8 @@ def kernel_blocksparse_row_wise_add(x,
453
436
 
454
437
  # Store block
455
438
  blk_o_idx = ((pid_blk * o_b_s) +
456
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
457
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
458
- blk_o_msk = ((blk_o_idx >= 0 and
459
- blk_o_idx < o_b * o_b_s) and
460
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
461
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
439
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
440
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
441
+ blk_o_msk = (blk_o_idx >= 0 and
442
+ blk_o_idx < o_b * o_b_s)
462
443
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
blksprs/ops/softmax.py CHANGED
@@ -7,7 +7,8 @@ from triton import language as tl
7
7
 
8
8
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
9
9
  from blksprs.utils.blksprs_tensor import BlksprsTensor
10
- from blksprs.utils.tools import stride, get_autotune_configs
10
+ from blksprs.utils.tools import stride
11
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
11
12
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
12
13
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
13
14
 
@@ -114,7 +115,8 @@ def softmax_backward(ctx, grad_output):
114
115
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
115
116
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
116
117
 
117
- (wrap_triton(softmax_kernel_grad)[triton_grid]
118
+ # TODO wrap
119
+ (softmax_kernel_grad[triton_grid]
118
120
  (grad_output,
119
121
  o_b, o_b_s, o_r_s, o_c_s,
120
122
  o,
@@ -133,7 +135,8 @@ def softmax_backward(ctx, grad_output):
133
135
 
134
136
  @triton.autotune(
135
137
  configs=get_autotune_configs(),
136
- key=[],
138
+ key=["sparsity_block_size"],
139
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
137
140
  reset_to_zero=["o"]
138
141
  )
139
142
  @triton.jit
@@ -151,9 +154,6 @@ def softmax_kernel(x,
151
154
  pid_row = tl.program_id(axis=1)
152
155
  pid_col = tl.program_id(axis=2)
153
156
 
154
- # Get valid triton block size
155
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
156
-
157
157
  # Get position of current sparsity block consisting of its batch and row index
158
158
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
159
159
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -172,22 +172,18 @@ def softmax_kernel(x,
172
172
  if rev_idx_spa_s >= 0:
173
173
  # Load x block
174
174
  blk_x_idx = ((pid_blk * x_b_s) +
175
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
176
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
177
- blk_x_msk = ((blk_x_idx >= 0 and
178
- blk_x_idx < x_b * x_b_s) and
179
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
180
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
175
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
176
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
177
+ blk_x_msk = (blk_x_idx >= 0 and
178
+ blk_x_idx < x_b * x_b_s)
181
179
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
182
180
 
183
181
  # Load sum block
184
182
  blk_s_idx = (rev_idx_spa_s * s_b_s +
185
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
183
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
186
184
  (tl.arange(0, 1) * s_c_s)[None, :])
187
- blk_s_msk = ((blk_s_idx >= 0 and
188
- blk_s_idx < s_b * s_b_s) and
189
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
190
- tl.arange(0, 1)[None, :] < val_tbs))
185
+ blk_s_msk = (blk_s_idx >= 0 and
186
+ blk_s_idx < s_b * s_b_s)
191
187
  blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
192
188
 
193
189
  # Compute softmax
@@ -199,7 +195,8 @@ def softmax_kernel(x,
199
195
 
200
196
  @triton.autotune(
201
197
  configs=get_autotune_configs(),
202
- key=[],
198
+ key=["sparsity_block_size"],
199
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
203
200
  reset_to_zero=["o"]
204
201
  )
205
202
  @triton.jit
@@ -221,9 +218,6 @@ def softmax_kernel_grad(g,
221
218
  pid_row = tl.program_id(axis=1)
222
219
  pid_col = tl.program_id(axis=2)
223
220
 
224
- # Get valid triton block size
225
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
226
-
227
221
  # Get position of current sparsity block consisting of its batch and row index
228
222
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
229
223
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -240,41 +234,33 @@ def softmax_kernel_grad(g,
240
234
 
241
235
  if rev_idx_spa_s >= 0:
242
236
  blk_s_idx = (rev_idx_spa_s * s_b_s +
243
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
237
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
244
238
  (tl.arange(0, 1) * s_c_s)[None, :])
245
- blk_s_msk = ((blk_s_idx >= 0 and
246
- blk_s_idx < s_b * s_b_s) and
247
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
248
- tl.arange(0, 1)[None, :] < val_tbs))
239
+ blk_s_msk = (blk_s_idx >= 0 and
240
+ blk_s_idx < s_b * s_b_s)
249
241
  blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
250
242
 
251
243
  blk_g_idx = ((pid_blk * g_b_s) +
252
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
253
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
254
- blk_g_msk = ((blk_g_idx >= 0 and
255
- blk_g_idx < g_b * g_b_s) and
256
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
257
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
244
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
245
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
246
+ blk_g_msk = (blk_g_idx >= 0 and
247
+ blk_g_idx < g_b * g_b_s)
258
248
  blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
259
249
 
260
250
  blk_x_idx = ((pid_blk * x_b_s) +
261
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
262
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
263
- blk_x_msk = ((blk_x_idx >= 0 and
264
- blk_x_idx < x_b * x_b_s) and
265
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
266
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
251
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
252
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
253
+ blk_x_msk = (blk_x_idx >= 0 and
254
+ blk_x_idx < x_b * x_b_s)
267
255
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
268
256
 
269
257
  buf = blk_x * (blk_g - blk_s)
270
258
 
271
259
  blk_o_idx = ((pid_blk * o_b_s) +
272
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
273
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
274
- blk_o_msk = ((blk_o_idx >= 0 and
275
- blk_o_idx < o_b * o_b_s) and
276
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
277
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
260
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
261
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
262
+ blk_o_msk = (blk_o_idx >= 0 and
263
+ blk_o_idx < o_b * o_b_s)
278
264
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
279
265
 
280
266