blksprs 2.0rc3__py3-none-any.whl → 2.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/flow.py CHANGED
@@ -5,14 +5,15 @@ from torch._library import triton_op
5
5
  from torch._library.triton import wrap_triton
6
6
  from triton import language as tl
7
7
 
8
- from blksprs.utils.tools import stride, get_autotune_configs
8
+ from blksprs.utils.tools import stride
9
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
9
10
 
10
11
 
11
12
  @triton_op("blksprs::flow_pull", mutates_args={})
12
13
  def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
13
14
  sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
14
15
  sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
15
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
16
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
16
17
  dtype=x.dtype, device=x.device)
17
18
 
18
19
  x_b, x_r, x_c = x.size()
@@ -43,7 +44,9 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
43
44
 
44
45
  @triton.autotune(
45
46
  configs=get_autotune_configs(),
46
- key=[],
47
+ key=["sparsity_block_size"],
48
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
49
+ reset_to_zero=["o"]
47
50
  )
48
51
  @triton.jit
49
52
  def flow_pull_kernel(x,
@@ -60,9 +63,6 @@ def flow_pull_kernel(x,
60
63
  pid_row = tl.program_id(axis=1)
61
64
  pid_col = tl.program_id(axis=2)
62
65
 
63
- # Get valid triton block size
64
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
65
-
66
66
  # Get sparsity index of current output block consisting of its batch, row, and column index
67
67
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
68
68
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -85,21 +85,17 @@ def flow_pull_kernel(x,
85
85
 
86
86
  if rev_idx_spa >= 0:
87
87
  blk_x_idx = (rev_idx_spa * x_b_s +
88
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
89
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
90
- blk_x_msk = ((blk_x_idx >= 0 and
91
- blk_x_idx < x_b * x_b_s) and
92
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
93
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
88
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
89
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
90
+ blk_x_msk = (blk_x_idx >= 0 and
91
+ blk_x_idx < x_b * x_b_s)
94
92
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
95
93
 
96
94
  blk_o_idx = (pid_blk * o_b_s +
97
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
98
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
99
- blk_o_msk = ((blk_o_idx >= 0 and
100
- blk_o_idx < o_b * o_b_s) and
101
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
102
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
95
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
96
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
97
+ blk_o_msk = (blk_o_idx >= 0 and
98
+ blk_o_idx < o_b * o_b_s)
103
99
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
104
100
 
105
101
 
@@ -137,7 +133,8 @@ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor
137
133
 
138
134
  @triton.autotune(
139
135
  configs=get_autotune_configs(),
140
- key=[],
136
+ key=["sparsity_block_size"],
137
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
141
138
  reset_to_zero=["o"]
142
139
  )
143
140
  @triton.jit
@@ -155,9 +152,6 @@ def flow_push_kernel(x,
155
152
  pid_row = tl.program_id(axis=1)
156
153
  pid_col = tl.program_id(axis=2)
157
154
 
158
- # Get valid triton block size
159
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
160
-
161
155
  # Get sparsity index of current input block consisting of its batch, row, and column index
162
156
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
163
157
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -180,19 +174,15 @@ def flow_push_kernel(x,
180
174
 
181
175
  if rev_idx_spa >= 0:
182
176
  blk_x_idx = (pid_blk * x_b_s +
183
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
184
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
185
- blk_x_msk = ((blk_x_idx >= 0 and
186
- blk_x_idx < x_b * x_b_s) and
187
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
188
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
177
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
178
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
179
+ blk_x_msk = (blk_x_idx >= 0 and
180
+ blk_x_idx < x_b * x_b_s)
189
181
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
190
182
 
191
183
  blk_o_idx = (rev_idx_spa * o_b_s +
192
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
193
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
194
- blk_o_msk = ((blk_o_idx >= 0 and
195
- blk_o_idx < o_b * o_b_s) and
196
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
197
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
184
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
185
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
186
+ blk_o_msk = (blk_o_idx >= 0 and
187
+ blk_o_idx < o_b * o_b_s)
198
188
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
blksprs/ops/matmul.py CHANGED
@@ -6,7 +6,8 @@ from triton import language as tl
6
6
 
7
7
  from blksprs.ops.transpose import transpose
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float
12
13
 
@@ -60,7 +61,7 @@ def matmul_forward(x: Tensor, y: Tensor,
60
61
  sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
61
62
  _: Tensor, sparsity_lut_o: Tensor,
62
63
  sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
63
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
64
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
64
65
  dtype=x.dtype, device=x.device)
65
66
 
66
67
  x_b, x_r, x_c = x.size()
@@ -117,7 +118,9 @@ def matmul_backward(ctx, grad_output):
117
118
 
118
119
  @triton.autotune(
119
120
  configs=get_autotune_configs(),
120
- key=[],
121
+ key=["sparsity_block_size"],
122
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
123
+ reset_to_zero=["o"]
121
124
  )
122
125
  @triton.jit
123
126
  def matmul_kernel(x,
@@ -140,9 +143,6 @@ def matmul_kernel(x,
140
143
  pid_row = tl.program_id(axis=1)
141
144
  pid_col = tl.program_id(axis=2)
142
145
 
143
- # Get valid triton block size
144
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
145
-
146
146
  # Get position of current sparsity block consisting of its batch, row, and column index
147
147
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
148
148
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -160,11 +160,11 @@ def matmul_kernel(x,
160
160
  buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
161
161
 
162
162
  # Slide over triton block sized segments of input tensors
163
- for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, val_tbs)):
163
+ for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
164
164
  # Convert to segment index of sparsity layout
165
- i_seg_spa = (i_seg_tri * val_tbs) // sparsity_block_size
165
+ i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
166
166
  # Calculate the triton segment index within a block
167
- i_seg_tri_mod = i_seg_tri % (sparsity_block_size // val_tbs)
167
+ i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
168
168
 
169
169
  # Get reverse sparsity indices for input tensors x and y
170
170
  # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
@@ -184,23 +184,23 @@ def matmul_kernel(x,
184
184
  # If both blocks are present commence calculation
185
185
  if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
186
186
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
187
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
188
- ((i_seg_tri_mod * val_tbs +
187
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
188
+ ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
189
189
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
190
190
  blk_x_msk = ((blk_x_idx >= 0 and
191
191
  blk_x_idx < x_b * x_b_s) and
192
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
193
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
192
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
193
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
194
194
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
195
195
 
196
196
  blk_y_idx = ((rev_idx_spa_y * y_b_s) +
197
- ((i_seg_tri_mod * val_tbs +
197
+ ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
198
198
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
199
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
199
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
200
200
  blk_y_msk = ((blk_y_idx >= 0 and
201
201
  blk_y_idx < y_b * y_b_s) and
202
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
203
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
202
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
203
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
204
204
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
205
205
 
206
206
  # Perform matrix multiplication
@@ -211,12 +211,12 @@ def matmul_kernel(x,
211
211
 
212
212
  # Store output
213
213
  blk_o_idx = ((pid_blk * o_b_s) +
214
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
215
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
214
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
215
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
216
216
  blk_o_msk = ((blk_o_idx >= 0 and
217
217
  blk_o_idx < o_b * o_b_s) and
218
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
219
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
218
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
219
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
220
220
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
221
221
 
222
222
 
@@ -6,7 +6,8 @@ from torch._library.triton import wrap_triton
6
6
  from triton import language as tl
7
7
 
8
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
9
- from blksprs.utils.tools import stride, get_autotune_configs
9
+ from blksprs.utils.tools import stride
10
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
11
  from blksprs.utils.validation import validate_contiguous, validate_device, \
11
12
  validate_sparsity_block_size
12
13
 
@@ -87,7 +88,8 @@ def broadcast_add_forward(x: Tensor, y: Tensor,
87
88
 
88
89
  @triton.autotune(
89
90
  configs=get_autotune_configs(),
90
- key=[],
91
+ key=["sparsity_block_size"],
92
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
91
93
  reset_to_zero=["o"]
92
94
  )
93
95
  @triton.jit
@@ -105,9 +107,6 @@ def broadcast_add_kernel(x,
105
107
  pid_row = tl.program_id(axis=1)
106
108
  pid_col = tl.program_id(axis=2)
107
109
 
108
- # Get valid triton block size
109
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
110
-
111
110
  # Get position of current sparsity block consisting of its batch, row, and column index
112
111
  spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
113
112
  spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -123,20 +122,18 @@ def broadcast_add_kernel(x,
123
122
 
124
123
  # Load x block
125
124
  blk_x_idx = (spa_bat_o * x_b_s +
126
- ((pid_row * val_tbs + spa_row_o * sparsity_block_size +
125
+ ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
127
126
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
128
- blk_x_msk = ((blk_x_idx >= 0 and
129
- blk_x_idx < x_b * x_b_s) and
130
- (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
127
+ blk_x_msk = (blk_x_idx >= 0 and
128
+ blk_x_idx < x_b * x_b_s)
131
129
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
132
130
 
133
131
  # Load y block
134
132
  blk_y_idx = (spa_bat_o * y_b_s +
135
- ((pid_col * val_tbs + spa_col_o * sparsity_block_size +
133
+ ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
136
134
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
137
- blk_y_msk = ((blk_y_idx >= 0 and
138
- blk_y_idx < y_b * y_b_s) and
139
- (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
135
+ blk_y_msk = (blk_y_idx >= 0 and
136
+ blk_y_idx < y_b * y_b_s)
140
137
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
141
138
 
142
139
  # Compute sum
@@ -145,10 +142,8 @@ def broadcast_add_kernel(x,
145
142
 
146
143
  # Store result
147
144
  blk_o_idx = ((pid_blk * o_b_s) +
148
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
149
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
150
- blk_o_msk = ((blk_o_idx >= 0 and
151
- blk_o_idx < o_b * o_b_s) and
152
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
153
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
145
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
146
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
147
+ blk_o_msk = (blk_o_idx >= 0 and
148
+ blk_o_idx < o_b * o_b_s)
154
149
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
@@ -5,7 +5,8 @@ from torch._library.triton import wrap_triton, triton_op
5
5
  from triton import language as tl
6
6
 
7
7
  from blksprs.utils.blksprs_tensor import BlksprsTensor
8
- from blksprs.utils.tools import stride, get_autotune_configs
8
+ from blksprs.utils.tools import stride, get_autocast_min_val
9
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
9
10
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
10
11
  validate_sparsity_block_size
11
12
 
@@ -96,7 +97,8 @@ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
96
97
 
97
98
  @triton.autotune(
98
99
  configs=get_autotune_configs(),
99
- key=[],
100
+ key=["sparsity_block_size"],
101
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
100
102
  reset_to_zero=["o"]
101
103
  )
102
104
  @triton.jit
@@ -114,9 +116,6 @@ def row_wise_sum_kernel(x,
114
116
  pid_row = tl.program_id(axis=1)
115
117
  pid_col = tl.program_id(axis=2)
116
118
 
117
- # Get valid triton block size
118
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
119
-
120
119
  # Get position of current sparsity block consisting of its batch and row index
121
120
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
122
121
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -137,23 +136,19 @@ def row_wise_sum_kernel(x,
137
136
  return
138
137
 
139
138
  blk_idx = ((pid_blk * x_b_s) +
140
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
141
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
142
- blk_msk = ((blk_idx >= 0 and
143
- blk_idx < x_b * x_b_s) and
144
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
145
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
139
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
140
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
141
+ blk_msk = (blk_idx >= 0 and
142
+ blk_idx < x_b * x_b_s)
146
143
  blk = tl.load(x + blk_idx, mask=blk_msk)
147
144
 
148
145
  buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
149
146
 
150
147
  o_idx = (rev_idx_spa * o_b_s +
151
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
148
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
152
149
  (tl.arange(0, 1))[None, :])
153
- o_msk = ((o_idx >= 0 and
154
- o_idx < o_b * o_b_s) and
155
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
156
- tl.arange(0, 1)[None, :] < val_tbs))
150
+ o_msk = (o_idx >= 0 and
151
+ o_idx < o_b * o_b_s)
157
152
  tl.atomic_add(o + o_idx, buf, o_msk)
158
153
 
159
154
 
@@ -214,7 +209,7 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
214
209
  output = torch.full(size=(n_sparse_blocks_output,
215
210
  sparsity_block_size,
216
211
  1 if flag_slice_only else sparsity_block_size),
217
- fill_value=float("-inf"),
212
+ fill_value=get_autocast_min_val(),
218
213
  device=x.device)
219
214
 
220
215
  x_b, x_r, x_c = x.size()
@@ -245,7 +240,8 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
245
240
 
246
241
  @triton.autotune(
247
242
  configs=get_autotune_configs(),
248
- key=[],
243
+ key=["sparsity_block_size"],
244
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
249
245
  restore_value=["o"]
250
246
  )
251
247
  @triton.jit
@@ -263,9 +259,6 @@ def row_wise_max_kernel(x,
263
259
  pid_row = tl.program_id(axis=1)
264
260
  pid_col = tl.program_id(axis=2)
265
261
 
266
- # Get valid triton block size
267
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
268
-
269
262
  # Get position of current sparsity block consisting of its batch and row index
270
263
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
271
264
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -286,23 +279,19 @@ def row_wise_max_kernel(x,
286
279
  return
287
280
 
288
281
  blk_idx = ((pid_blk * x_b_s) +
289
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
290
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
291
- blk_msk = ((blk_idx >= 0 and
292
- blk_idx < x_b * x_b_s) and
293
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
294
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
282
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
283
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
284
+ blk_msk = (blk_idx >= 0 and
285
+ blk_idx < x_b * x_b_s)
295
286
  blk = tl.load(x + blk_idx, mask=blk_msk)
296
287
 
297
288
  buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
298
289
 
299
290
  o_idx = (rev_idx_spa * o_b_s +
300
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
291
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
301
292
  (tl.arange(0, 1))[None, :])
302
- o_msk = ((o_idx >= 0 and
303
- o_idx < o_b * o_b_s) and
304
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
305
- tl.arange(0, 1)[None, :] < val_tbs))
293
+ o_msk = (o_idx >= 0 and
294
+ o_idx < o_b * o_b_s)
306
295
  tl.atomic_max(o + o_idx, buf, o_msk)
307
296
 
308
297
 
@@ -354,7 +343,7 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
354
343
  def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
355
344
  sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
356
345
  y: Tensor, sparsity_block_size: int) -> Tensor:
357
- output = torch.empty_like(x)
346
+ output = torch.zeros_like(x)
358
347
 
359
348
  x_b, x_r, x_c = x.size()
360
349
  x_b_s, x_r_s, x_c_s = stride(x)
@@ -371,7 +360,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
371
360
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
372
361
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
373
362
 
374
- (kernel_blocksparse_row_wise_add[triton_grid]
363
+ (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
375
364
  (x,
376
365
  x_b, x_b_s, x_r_s, x_c_s,
377
366
  sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -387,7 +376,9 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
387
376
 
388
377
  @triton.autotune(
389
378
  configs=get_autotune_configs(),
390
- key=[]
379
+ key=["sparsity_block_size"],
380
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
381
+ reset_to_zero=["o"]
391
382
  )
392
383
  @triton.jit
393
384
  def kernel_blocksparse_row_wise_add(x,
@@ -405,9 +396,6 @@ def kernel_blocksparse_row_wise_add(x,
405
396
  pid_row = tl.program_id(axis=1)
406
397
  pid_col = tl.program_id(axis=2)
407
398
 
408
- # Get valid triton block size
409
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
410
-
411
399
  # Get position of current sparsity block consisting of its batch and row index
412
400
  spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
413
401
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -429,22 +417,18 @@ def kernel_blocksparse_row_wise_add(x,
429
417
 
430
418
  # Load x block
431
419
  blk_x_idx = ((pid_blk * x_b_s) +
432
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
433
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
434
- blk_x_msk = ((blk_x_idx >= 0 and
435
- blk_x_idx < x_b * x_b_s) and
436
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
437
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
420
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
421
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
422
+ blk_x_msk = (blk_x_idx >= 0 and
423
+ blk_x_idx < x_b * x_b_s)
438
424
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
439
425
 
440
426
  # Load sum block
441
427
  blk_s_idx = (rev_idx_spa_s * y_b_s +
442
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
428
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
443
429
  (tl.arange(0, 1) * y_c_s)[None, :])
444
- blk_s_msk = ((blk_s_idx >= 0 and
445
- blk_s_idx < y_b * y_b_s) and
446
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
447
- tl.arange(0, 1)[None, :] < val_tbs))
430
+ blk_s_msk = (blk_s_idx >= 0 and
431
+ blk_s_idx < y_b * y_b_s)
448
432
  blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
449
433
 
450
434
  # Compute exp
@@ -452,10 +436,8 @@ def kernel_blocksparse_row_wise_add(x,
452
436
 
453
437
  # Store block
454
438
  blk_o_idx = ((pid_blk * o_b_s) +
455
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
456
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
457
- blk_o_msk = ((blk_o_idx >= 0 and
458
- blk_o_idx < o_b * o_b_s) and
459
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
460
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
439
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
440
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
441
+ blk_o_msk = (blk_o_idx >= 0 and
442
+ blk_o_idx < o_b * o_b_s)
461
443
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
blksprs/ops/softmax.py CHANGED
@@ -7,7 +7,8 @@ from triton import language as tl
7
7
 
8
8
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
9
9
  from blksprs.utils.blksprs_tensor import BlksprsTensor
10
- from blksprs.utils.tools import stride, get_autotune_configs
10
+ from blksprs.utils.tools import stride
11
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
11
12
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
12
13
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
13
14
 
@@ -51,7 +52,7 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
51
52
  sparsity_lut: Tensor,
52
53
  sparsity_reverse_lut_rws: Tensor,
53
54
  sparsity_block_size: int) -> Tensor:
54
- output = torch.empty_like(x)
55
+ output = torch.zeros_like(x)
55
56
 
56
57
  x_b, x_r, x_c = x.size()
57
58
  x_b_s, x_r_s, x_c_s = stride(x)
@@ -108,13 +109,14 @@ def softmax_backward(ctx, grad_output):
108
109
  s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
109
110
  s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
110
111
 
111
- grad_x = torch.empty_like(o, dtype=torch.float)
112
+ grad_x = torch.zeros_like(o, dtype=torch.float)
112
113
 
113
114
  triton_grid = lambda meta: [o_b,
114
115
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
115
116
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
116
117
 
117
- (wrap_triton(softmax_kernel_grad)[triton_grid]
118
+ # TODO wrap
119
+ (softmax_kernel_grad[triton_grid]
118
120
  (grad_output,
119
121
  o_b, o_b_s, o_r_s, o_c_s,
120
122
  o,
@@ -133,7 +135,9 @@ def softmax_backward(ctx, grad_output):
133
135
 
134
136
  @triton.autotune(
135
137
  configs=get_autotune_configs(),
136
- key=[]
138
+ key=["sparsity_block_size"],
139
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
140
+ reset_to_zero=["o"]
137
141
  )
138
142
  @triton.jit
139
143
  def softmax_kernel(x,
@@ -150,9 +154,6 @@ def softmax_kernel(x,
150
154
  pid_row = tl.program_id(axis=1)
151
155
  pid_col = tl.program_id(axis=2)
152
156
 
153
- # Get valid triton block size
154
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
155
-
156
157
  # Get position of current sparsity block consisting of its batch and row index
157
158
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
158
159
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -171,22 +172,18 @@ def softmax_kernel(x,
171
172
  if rev_idx_spa_s >= 0:
172
173
  # Load x block
173
174
  blk_x_idx = ((pid_blk * x_b_s) +
174
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
175
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
176
- blk_x_msk = ((blk_x_idx >= 0 and
177
- blk_x_idx < x_b * x_b_s) and
178
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
179
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
175
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
176
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
177
+ blk_x_msk = (blk_x_idx >= 0 and
178
+ blk_x_idx < x_b * x_b_s)
180
179
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
181
180
 
182
181
  # Load sum block
183
182
  blk_s_idx = (rev_idx_spa_s * s_b_s +
184
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
183
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
185
184
  (tl.arange(0, 1) * s_c_s)[None, :])
186
- blk_s_msk = ((blk_s_idx >= 0 and
187
- blk_s_idx < s_b * s_b_s) and
188
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
189
- tl.arange(0, 1)[None, :] < val_tbs))
185
+ blk_s_msk = (blk_s_idx >= 0 and
186
+ blk_s_idx < s_b * s_b_s)
190
187
  blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
191
188
 
192
189
  # Compute softmax
@@ -198,7 +195,9 @@ def softmax_kernel(x,
198
195
 
199
196
  @triton.autotune(
200
197
  configs=get_autotune_configs(),
201
- key=[]
198
+ key=["sparsity_block_size"],
199
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
200
+ reset_to_zero=["o"]
202
201
  )
203
202
  @triton.jit
204
203
  def softmax_kernel_grad(g,
@@ -219,9 +218,6 @@ def softmax_kernel_grad(g,
219
218
  pid_row = tl.program_id(axis=1)
220
219
  pid_col = tl.program_id(axis=2)
221
220
 
222
- # Get valid triton block size
223
- val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
224
-
225
221
  # Get position of current sparsity block consisting of its batch and row index
226
222
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
227
223
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -238,41 +234,33 @@ def softmax_kernel_grad(g,
238
234
 
239
235
  if rev_idx_spa_s >= 0:
240
236
  blk_s_idx = (rev_idx_spa_s * s_b_s +
241
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
237
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
242
238
  (tl.arange(0, 1) * s_c_s)[None, :])
243
- blk_s_msk = ((blk_s_idx >= 0 and
244
- blk_s_idx < s_b * s_b_s) and
245
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
246
- tl.arange(0, 1)[None, :] < val_tbs))
239
+ blk_s_msk = (blk_s_idx >= 0 and
240
+ blk_s_idx < s_b * s_b_s)
247
241
  blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
248
242
 
249
243
  blk_g_idx = ((pid_blk * g_b_s) +
250
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
251
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
252
- blk_g_msk = ((blk_g_idx >= 0 and
253
- blk_g_idx < g_b * g_b_s) and
254
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
255
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
244
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
245
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
246
+ blk_g_msk = (blk_g_idx >= 0 and
247
+ blk_g_idx < g_b * g_b_s)
256
248
  blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
257
249
 
258
250
  blk_x_idx = ((pid_blk * x_b_s) +
259
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
260
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
261
- blk_x_msk = ((blk_x_idx >= 0 and
262
- blk_x_idx < x_b * x_b_s) and
263
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
264
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
251
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
252
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
253
+ blk_x_msk = (blk_x_idx >= 0 and
254
+ blk_x_idx < x_b * x_b_s)
265
255
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
266
256
 
267
257
  buf = blk_x * (blk_g - blk_s)
268
258
 
269
259
  blk_o_idx = ((pid_blk * o_b_s) +
270
- ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
271
- ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
272
- blk_o_msk = ((blk_o_idx >= 0 and
273
- blk_o_idx < o_b * o_b_s) and
274
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
275
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
260
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
261
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
262
+ blk_o_msk = (blk_o_idx >= 0 and
263
+ blk_o_idx < o_b * o_b_s)
276
264
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
277
265
 
278
266