blksprs 2.0rc4__py3-none-any.whl → 2.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/flow.py CHANGED
@@ -5,7 +5,8 @@ from torch._library import triton_op
5
5
  from torch._library.triton import wrap_triton
6
6
  from triton import language as tl
7
7
 
8
- from blksprs.utils.tools import stride, get_autotune_configs
8
+ from blksprs.utils.tools import stride
9
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
9
10
 
10
11
 
11
12
  @triton_op("blksprs::flow_pull", mutates_args={})
@@ -43,7 +44,8 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
43
44
 
44
45
  @triton.autotune(
45
46
  configs=get_autotune_configs(),
46
- key=[],
47
+ key=["sparsity_block_size"],
48
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
47
49
  reset_to_zero=["o"]
48
50
  )
49
51
  @triton.jit
@@ -61,9 +63,6 @@ def flow_pull_kernel(x,
61
63
  pid_row = tl.program_id(axis=1)
62
64
  pid_col = tl.program_id(axis=2)
63
65
 
64
- # Get valid triton block size
65
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
66
-
67
66
  # Get sparsity index of current output block consisting of its batch, row, and column index
68
67
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
69
68
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -86,21 +85,17 @@ def flow_pull_kernel(x,
86
85
 
87
86
  if rev_idx_spa >= 0:
88
87
  blk_x_idx = (rev_idx_spa * x_b_s +
89
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
90
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
91
- blk_x_msk = ((blk_x_idx >= 0 and
92
- blk_x_idx < x_b * x_b_s) and
93
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
94
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
88
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
89
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
90
+ blk_x_msk = (blk_x_idx >= 0 and
91
+ blk_x_idx < x_b * x_b_s)
95
92
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
96
93
 
97
94
  blk_o_idx = (pid_blk * o_b_s +
98
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
99
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
100
- blk_o_msk = ((blk_o_idx >= 0 and
101
- blk_o_idx < o_b * o_b_s) and
102
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
103
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
95
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
96
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
97
+ blk_o_msk = (blk_o_idx >= 0 and
98
+ blk_o_idx < o_b * o_b_s)
104
99
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
105
100
 
106
101
 
@@ -138,7 +133,8 @@ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor
138
133
 
139
134
  @triton.autotune(
140
135
  configs=get_autotune_configs(),
141
- key=[],
136
+ key=["sparsity_block_size"],
137
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
142
138
  reset_to_zero=["o"]
143
139
  )
144
140
  @triton.jit
@@ -156,9 +152,6 @@ def flow_push_kernel(x,
156
152
  pid_row = tl.program_id(axis=1)
157
153
  pid_col = tl.program_id(axis=2)
158
154
 
159
- # Get valid triton block size
160
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
161
-
162
155
  # Get sparsity index of current input block consisting of its batch, row, and column index
163
156
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
164
157
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -181,19 +174,15 @@ def flow_push_kernel(x,
181
174
 
182
175
  if rev_idx_spa >= 0:
183
176
  blk_x_idx = (pid_blk * x_b_s +
184
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
185
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
186
- blk_x_msk = ((blk_x_idx >= 0 and
187
- blk_x_idx < x_b * x_b_s) and
188
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
189
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
177
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
178
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
179
+ blk_x_msk = (blk_x_idx >= 0 and
180
+ blk_x_idx < x_b * x_b_s)
190
181
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
191
182
 
192
183
  blk_o_idx = (rev_idx_spa * o_b_s +
193
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
194
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
195
- blk_o_msk = ((blk_o_idx >= 0 and
196
- blk_o_idx < o_b * o_b_s) and
197
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
198
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
184
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
185
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
186
+ blk_o_msk = (blk_o_idx >= 0 and
187
+ blk_o_idx < o_b * o_b_s)
199
188
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
blksprs/ops/matmul.py CHANGED
@@ -6,7 +6,8 @@ from triton import language as tl
6
6
 
7
7
  from blksprs.ops.transpose import transpose
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float
12
13
 
@@ -117,7 +118,8 @@ def matmul_backward(ctx, grad_output):
117
118
 
118
119
  @triton.autotune(
119
120
  configs=get_autotune_configs(),
120
- key=[],
121
+ key=["sparsity_block_size"],
122
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
121
123
  reset_to_zero=["o"]
122
124
  )
123
125
  @triton.jit
@@ -141,9 +143,6 @@ def matmul_kernel(x,
141
143
  pid_row = tl.program_id(axis=1)
142
144
  pid_col = tl.program_id(axis=2)
143
145
 
144
- # Get valid triton block size
145
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
146
-
147
146
  # Get position of current sparsity block consisting of its batch, row, and column index
148
147
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
149
148
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -161,11 +160,11 @@ def matmul_kernel(x,
161
160
  buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
162
161
 
163
162
  # Slide over triton block sized segments of input tensors
164
- for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, val_tbs)):
163
+ for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
165
164
  # Convert to segment index of sparsity layout
166
- i_seg_spa = (i_seg_tri * val_tbs) // sparsity_block_size
165
+ i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
167
166
  # Calculate the triton segment index within a block
168
- i_seg_tri_mod = i_seg_tri % (sparsity_block_size // val_tbs)
167
+ i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
169
168
 
170
169
  # Get reverse sparsity indices for input tensors x and y
171
170
  # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
@@ -185,23 +184,23 @@ def matmul_kernel(x,
185
184
  # If both blocks are present commence calculation
186
185
  if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
187
186
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
188
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
189
- ((i_seg_tri_mod * val_tbs +
187
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
188
+ ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
190
189
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
191
190
  blk_x_msk = ((blk_x_idx >= 0 and
192
191
  blk_x_idx < x_b * x_b_s) and
193
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
194
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
192
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
193
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
195
194
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
196
195
 
197
196
  blk_y_idx = ((rev_idx_spa_y * y_b_s) +
198
- ((i_seg_tri_mod * val_tbs +
197
+ ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
199
198
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
200
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
199
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
201
200
  blk_y_msk = ((blk_y_idx >= 0 and
202
201
  blk_y_idx < y_b * y_b_s) and
203
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
204
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
202
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
203
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
205
204
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
206
205
 
207
206
  # Perform matrix multiplication
@@ -212,12 +211,12 @@ def matmul_kernel(x,
212
211
 
213
212
  # Store output
214
213
  blk_o_idx = ((pid_blk * o_b_s) +
215
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
216
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
214
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
215
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
217
216
  blk_o_msk = ((blk_o_idx >= 0 and
218
217
  blk_o_idx < o_b * o_b_s) and
219
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
220
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
218
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
219
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
221
220
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
222
221
 
223
222
 
@@ -6,11 +6,13 @@ from torch._library.triton import wrap_triton
6
6
  from triton import language as tl
7
7
 
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
11
  from blksprs.utils.validation import validate_contiguous, validate_device, \
11
12
  validate_sparsity_block_size
12
13
 
13
14
 
15
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
14
16
  def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
15
17
  sparsity_block_size: int) -> BlksprsTensor:
16
18
  """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
@@ -87,7 +89,8 @@ def broadcast_add_forward(x: Tensor, y: Tensor,
87
89
 
88
90
  @triton.autotune(
89
91
  configs=get_autotune_configs(),
90
- key=[],
92
+ key=["sparsity_block_size"],
93
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
91
94
  reset_to_zero=["o"]
92
95
  )
93
96
  @triton.jit
@@ -105,9 +108,6 @@ def broadcast_add_kernel(x,
105
108
  pid_row = tl.program_id(axis=1)
106
109
  pid_col = tl.program_id(axis=2)
107
110
 
108
- # Get valid triton block size
109
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
110
-
111
111
  # Get position of current sparsity block consisting of its batch, row, and column index
112
112
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
113
113
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -123,20 +123,18 @@ def broadcast_add_kernel(x,
123
123
 
124
124
  # Load x block
125
125
  blk_x_idx = (spa_bat_o * x_b_s +
126
- ((pid_row * val_tbs + spa_row_o * sparsity_block_size +
126
+ ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
127
127
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
128
- blk_x_msk = ((blk_x_idx >= 0 and
129
- blk_x_idx < x_b * x_b_s) and
130
- (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
128
+ blk_x_msk = (blk_x_idx >= 0 and
129
+ blk_x_idx < x_b * x_b_s)
131
130
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
132
131
 
133
132
  # Load y block
134
133
  blk_y_idx = (spa_bat_o * y_b_s +
135
- ((pid_col * val_tbs + spa_col_o * sparsity_block_size +
134
+ ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
136
135
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
137
- blk_y_msk = ((blk_y_idx >= 0 and
138
- blk_y_idx < y_b * y_b_s) and
139
- (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
136
+ blk_y_msk = (blk_y_idx >= 0 and
137
+ blk_y_idx < y_b * y_b_s)
140
138
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
141
139
 
142
140
  # Compute sum
@@ -145,10 +143,8 @@ def broadcast_add_kernel(x,
145
143
 
146
144
  # Store result
147
145
  blk_o_idx = ((pid_blk * o_b_s) +
148
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
149
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
150
- blk_o_msk = ((blk_o_idx >= 0 and
151
- blk_o_idx < o_b * o_b_s) and
152
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
153
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
146
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
147
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
148
+ blk_o_msk = (blk_o_idx >= 0 and
149
+ blk_o_idx < o_b * o_b_s)
154
150
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
@@ -4,8 +4,9 @@ from torch import Tensor
4
4
  from torch._library.triton import wrap_triton, triton_op
5
5
  from triton import language as tl
6
6
 
7
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
7
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
8
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
9
10
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
10
11
  validate_sparsity_block_size
11
12
 
@@ -94,9 +95,11 @@ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
94
95
  return output
95
96
 
96
97
 
98
+ # noinspection PyUnusedLocal
97
99
  @triton.autotune(
98
100
  configs=get_autotune_configs(),
99
- key=[],
101
+ key=["sparsity_block_size"],
102
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
100
103
  reset_to_zero=["o"]
101
104
  )
102
105
  @triton.jit
@@ -114,9 +117,6 @@ def row_wise_sum_kernel(x,
114
117
  pid_row = tl.program_id(axis=1)
115
118
  pid_col = tl.program_id(axis=2)
116
119
 
117
- # Get valid triton block size
118
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
119
-
120
120
  # Get position of current sparsity block consisting of its batch and row index
121
121
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
122
122
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -137,23 +137,19 @@ def row_wise_sum_kernel(x,
137
137
  return
138
138
 
139
139
  blk_idx = ((pid_blk * x_b_s) +
140
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
141
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
142
- blk_msk = ((blk_idx >= 0 and
143
- blk_idx < x_b * x_b_s) and
144
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
145
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
140
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
141
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
142
+ blk_msk = (blk_idx >= 0 and
143
+ blk_idx < x_b * x_b_s)
146
144
  blk = tl.load(x + blk_idx, mask=blk_msk)
147
145
 
148
146
  buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
149
147
 
150
148
  o_idx = (rev_idx_spa * o_b_s +
151
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
149
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
152
150
  (tl.arange(0, 1))[None, :])
153
- o_msk = ((o_idx >= 0 and
154
- o_idx < o_b * o_b_s) and
155
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
156
- tl.arange(0, 1)[None, :] < val_tbs))
151
+ o_msk = (o_idx >= 0 and
152
+ o_idx < o_b * o_b_s)
157
153
  tl.atomic_add(o + o_idx, buf, o_msk)
158
154
 
159
155
 
@@ -180,6 +176,8 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
180
176
  of the input and the sparsity layout of the output tensor.
181
177
 
182
178
  """
179
+ # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376
180
+ x = torch.where(x == -0.0, torch.tensor(0.0), x)
183
181
  x = x.contiguous()
184
182
 
185
183
  validate_dimensions(x)
@@ -214,7 +212,7 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
214
212
  output = torch.full(size=(n_sparse_blocks_output,
215
213
  sparsity_block_size,
216
214
  1 if flag_slice_only else sparsity_block_size),
217
- fill_value=float("-inf"),
215
+ fill_value=torch.finfo(x.dtype).min,
218
216
  device=x.device)
219
217
 
220
218
  x_b, x_r, x_c = x.size()
@@ -243,9 +241,11 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
243
241
  return output
244
242
 
245
243
 
244
+ # noinspection PyUnusedLocal
246
245
  @triton.autotune(
247
246
  configs=get_autotune_configs(),
248
- key=[],
247
+ key=["sparsity_block_size"],
248
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
249
249
  restore_value=["o"]
250
250
  )
251
251
  @triton.jit
@@ -263,9 +263,6 @@ def row_wise_max_kernel(x,
263
263
  pid_row = tl.program_id(axis=1)
264
264
  pid_col = tl.program_id(axis=2)
265
265
 
266
- # Get valid triton block size
267
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
268
-
269
266
  # Get position of current sparsity block consisting of its batch and row index
270
267
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
271
268
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -286,23 +283,19 @@ def row_wise_max_kernel(x,
286
283
  return
287
284
 
288
285
  blk_idx = ((pid_blk * x_b_s) +
289
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
290
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
291
- blk_msk = ((blk_idx >= 0 and
292
- blk_idx < x_b * x_b_s) and
293
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
294
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
286
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
287
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
288
+ blk_msk = (blk_idx >= 0 and
289
+ blk_idx < x_b * x_b_s)
295
290
  blk = tl.load(x + blk_idx, mask=blk_msk)
296
291
 
297
292
  buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
298
293
 
299
294
  o_idx = (rev_idx_spa * o_b_s +
300
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
295
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
301
296
  (tl.arange(0, 1))[None, :])
302
- o_msk = ((o_idx >= 0 and
303
- o_idx < o_b * o_b_s) and
304
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
305
- tl.arange(0, 1)[None, :] < val_tbs))
297
+ o_msk = (o_idx >= 0 and
298
+ o_idx < o_b * o_b_s)
306
299
  tl.atomic_max(o + o_idx, buf, o_msk)
307
300
 
308
301
 
@@ -371,7 +364,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
371
364
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
372
365
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
373
366
 
374
- (kernel_blocksparse_row_wise_add[triton_grid]
367
+ (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
375
368
  (x,
376
369
  x_b, x_b_s, x_r_s, x_c_s,
377
370
  sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -387,7 +380,8 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
387
380
 
388
381
  @triton.autotune(
389
382
  configs=get_autotune_configs(),
390
- key=[],
383
+ key=["sparsity_block_size"],
384
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
391
385
  reset_to_zero=["o"]
392
386
  )
393
387
  @triton.jit
@@ -406,9 +400,6 @@ def kernel_blocksparse_row_wise_add(x,
406
400
  pid_row = tl.program_id(axis=1)
407
401
  pid_col = tl.program_id(axis=2)
408
402
 
409
- # Get valid triton block size
410
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
411
-
412
403
  # Get position of current sparsity block consisting of its batch and row index
413
404
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
414
405
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -430,22 +421,18 @@ def kernel_blocksparse_row_wise_add(x,
430
421
 
431
422
  # Load x block
432
423
  blk_x_idx = ((pid_blk * x_b_s) +
433
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
434
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
435
- blk_x_msk = ((blk_x_idx >= 0 and
436
- blk_x_idx < x_b * x_b_s) and
437
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
438
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
424
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
425
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
426
+ blk_x_msk = (blk_x_idx >= 0 and
427
+ blk_x_idx < x_b * x_b_s)
439
428
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
440
429
 
441
430
  # Load sum block
442
431
  blk_s_idx = (rev_idx_spa_s * y_b_s +
443
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
432
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
444
433
  (tl.arange(0, 1) * y_c_s)[None, :])
445
- blk_s_msk = ((blk_s_idx >= 0 and
446
- blk_s_idx < y_b * y_b_s) and
447
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
448
- tl.arange(0, 1)[None, :] < val_tbs))
434
+ blk_s_msk = (blk_s_idx >= 0 and
435
+ blk_s_idx < y_b * y_b_s)
449
436
  blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
450
437
 
451
438
  # Compute exp
@@ -453,10 +440,8 @@ def kernel_blocksparse_row_wise_add(x,
453
440
 
454
441
  # Store block
455
442
  blk_o_idx = ((pid_blk * o_b_s) +
456
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
457
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
458
- blk_o_msk = ((blk_o_idx >= 0 and
459
- blk_o_idx < o_b * o_b_s) and
460
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
461
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
443
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
444
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
445
+ blk_o_msk = (blk_o_idx >= 0 and
446
+ blk_o_idx < o_b * o_b_s)
462
447
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
blksprs/ops/softmax.py CHANGED
@@ -7,7 +7,8 @@ from triton import language as tl
7
7
 
8
8
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
9
9
  from blksprs.utils.blksprs_tensor import BlksprsTensor
10
- from blksprs.utils.tools import stride, get_autotune_configs
10
+ from blksprs.utils.tools import stride
11
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
11
12
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
12
13
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
13
14
 
@@ -114,7 +115,8 @@ def softmax_backward(ctx, grad_output):
114
115
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
115
116
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
116
117
 
117
- (wrap_triton(softmax_kernel_grad)[triton_grid]
118
+ # TODO wrap
119
+ (softmax_kernel_grad[triton_grid]
118
120
  (grad_output,
119
121
  o_b, o_b_s, o_r_s, o_c_s,
120
122
  o,
@@ -133,7 +135,8 @@ def softmax_backward(ctx, grad_output):
133
135
 
134
136
  @triton.autotune(
135
137
  configs=get_autotune_configs(),
136
- key=[],
138
+ key=["sparsity_block_size"],
139
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
137
140
  reset_to_zero=["o"]
138
141
  )
139
142
  @triton.jit
@@ -151,9 +154,6 @@ def softmax_kernel(x,
151
154
  pid_row = tl.program_id(axis=1)
152
155
  pid_col = tl.program_id(axis=2)
153
156
 
154
- # Get valid triton block size
155
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
156
-
157
157
  # Get position of current sparsity block consisting of its batch and row index
158
158
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
159
159
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -172,22 +172,18 @@ def softmax_kernel(x,
172
172
  if rev_idx_spa_s >= 0:
173
173
  # Load x block
174
174
  blk_x_idx = ((pid_blk * x_b_s) +
175
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
176
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
177
- blk_x_msk = ((blk_x_idx >= 0 and
178
- blk_x_idx < x_b * x_b_s) and
179
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
180
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
175
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
176
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
177
+ blk_x_msk = (blk_x_idx >= 0 and
178
+ blk_x_idx < x_b * x_b_s)
181
179
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
182
180
 
183
181
  # Load sum block
184
182
  blk_s_idx = (rev_idx_spa_s * s_b_s +
185
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
183
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
186
184
  (tl.arange(0, 1) * s_c_s)[None, :])
187
- blk_s_msk = ((blk_s_idx >= 0 and
188
- blk_s_idx < s_b * s_b_s) and
189
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
190
- tl.arange(0, 1)[None, :] < val_tbs))
185
+ blk_s_msk = (blk_s_idx >= 0 and
186
+ blk_s_idx < s_b * s_b_s)
191
187
  blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
192
188
 
193
189
  # Compute softmax
@@ -199,7 +195,8 @@ def softmax_kernel(x,
199
195
 
200
196
  @triton.autotune(
201
197
  configs=get_autotune_configs(),
202
- key=[],
198
+ key=["sparsity_block_size"],
199
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
203
200
  reset_to_zero=["o"]
204
201
  )
205
202
  @triton.jit
@@ -221,9 +218,6 @@ def softmax_kernel_grad(g,
221
218
  pid_row = tl.program_id(axis=1)
222
219
  pid_col = tl.program_id(axis=2)
223
220
 
224
- # Get valid triton block size
225
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
226
-
227
221
  # Get position of current sparsity block consisting of its batch and row index
228
222
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
229
223
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -240,41 +234,33 @@ def softmax_kernel_grad(g,
240
234
 
241
235
  if rev_idx_spa_s >= 0:
242
236
  blk_s_idx = (rev_idx_spa_s * s_b_s +
243
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
237
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
244
238
  (tl.arange(0, 1) * s_c_s)[None, :])
245
- blk_s_msk = ((blk_s_idx >= 0 and
246
- blk_s_idx < s_b * s_b_s) and
247
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
248
- tl.arange(0, 1)[None, :] < val_tbs))
239
+ blk_s_msk = (blk_s_idx >= 0 and
240
+ blk_s_idx < s_b * s_b_s)
249
241
  blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
250
242
 
251
243
  blk_g_idx = ((pid_blk * g_b_s) +
252
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
253
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
254
- blk_g_msk = ((blk_g_idx >= 0 and
255
- blk_g_idx < g_b * g_b_s) and
256
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
257
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
244
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
245
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
246
+ blk_g_msk = (blk_g_idx >= 0 and
247
+ blk_g_idx < g_b * g_b_s)
258
248
  blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
259
249
 
260
250
  blk_x_idx = ((pid_blk * x_b_s) +
261
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
262
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
263
- blk_x_msk = ((blk_x_idx >= 0 and
264
- blk_x_idx < x_b * x_b_s) and
265
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
266
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
251
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
252
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
253
+ blk_x_msk = (blk_x_idx >= 0 and
254
+ blk_x_idx < x_b * x_b_s)
267
255
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
268
256
 
269
257
  buf = blk_x * (blk_g - blk_s)
270
258
 
271
259
  blk_o_idx = ((pid_blk * o_b_s) +
272
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
273
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
274
- blk_o_msk = ((blk_o_idx >= 0 and
275
- blk_o_idx < o_b * o_b_s) and
276
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
277
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
260
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
261
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
262
+ blk_o_msk = (blk_o_idx >= 0 and
263
+ blk_o_idx < o_b * o_b_s)
278
264
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
279
265
 
280
266