blksprs 2.0rc7__py3-none-any.whl → 2.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +1 -0
- blksprs/layouting/distribution_layout.py +39 -26
- blksprs/layouting/sparsity_layout.py +58 -45
- blksprs/ops/conversion.py +86 -84
- blksprs/ops/distribution.py +80 -78
- blksprs/ops/flow.py +64 -60
- blksprs/ops/matmul.py +50 -55
- blksprs/ops/misc/broadcast_ops.py +28 -27
- blksprs/ops/misc/row_wise.py +123 -125
- blksprs/ops/partitioning.py +12 -10
- blksprs/ops/repeat.py +6 -5
- blksprs/ops/softmax.py +55 -47
- blksprs/ops/transpose.py +8 -7
- blksprs/utils/autotuning.py +10 -10
- blksprs/utils/processing.py +0 -1
- blksprs/utils/tools.py +8 -0
- {blksprs-2.0rc7.dist-info → blksprs-2.0rc8.dist-info}/METADATA +1 -1
- blksprs-2.0rc8.dist-info/RECORD +23 -0
- {blksprs-2.0rc7.dist-info → blksprs-2.0rc8.dist-info}/WHEEL +1 -1
- blksprs-2.0rc7.dist-info/RECORD +0 -23
- {blksprs-2.0rc7.dist-info → blksprs-2.0rc8.dist-info}/top_level.txt +0 -0
blksprs/ops/misc/row_wise.py
CHANGED
|
@@ -60,39 +60,40 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
60
60
|
sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
@triton_op("blksprs::
|
|
63
|
+
@triton_op("blksprs::row_wise_sum_forward", mutates_args={})
|
|
64
64
|
def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
|
|
65
65
|
sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
|
|
66
66
|
sparsity_block_size: int, n_sparse_blocks_output: int,
|
|
67
67
|
flag_slice_only: bool = False) -> Tensor:
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
68
|
+
with torch.no_grad():
|
|
69
|
+
output = torch.zeros(
|
|
70
|
+
size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
|
|
71
|
+
dtype=x.dtype, device=x.device)
|
|
72
|
+
|
|
73
|
+
x_b, x_r, x_c = x.size()
|
|
74
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
75
|
+
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
76
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
|
|
77
|
+
o_b, o_r, o_c = output.size()
|
|
78
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
79
|
+
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
80
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
81
|
+
|
|
82
|
+
triton_grid = lambda meta: [x_b,
|
|
83
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
84
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
85
|
+
|
|
86
|
+
(wrap_triton(row_wise_sum_kernel)[triton_grid]
|
|
87
|
+
(x,
|
|
88
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
89
|
+
sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
90
|
+
output,
|
|
91
|
+
o_b, o_b_s, o_r_s,
|
|
92
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
93
|
+
sparsity_reverse_lut_output,
|
|
94
|
+
sparsity_block_size))
|
|
95
|
+
|
|
96
|
+
return output
|
|
96
97
|
|
|
97
98
|
|
|
98
99
|
# noinspection PyUnusedLocal
|
|
@@ -132,25 +133,22 @@ def row_wise_sum_kernel(x,
|
|
|
132
133
|
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
133
134
|
rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
134
135
|
|
|
135
|
-
if rev_idx_spa
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
blk_msk = (blk_idx >= 0 and
|
|
143
|
-
blk_idx < x_b * x_b_s)
|
|
144
|
-
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
136
|
+
if rev_idx_spa >= 0:
|
|
137
|
+
blk_idx = ((pid_blk * x_b_s) +
|
|
138
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
139
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
140
|
+
blk_msk = (blk_idx >= 0 and
|
|
141
|
+
blk_idx < x_b * x_b_s)
|
|
142
|
+
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
145
143
|
|
|
146
|
-
|
|
144
|
+
buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
147
145
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
146
|
+
o_idx = (rev_idx_spa * o_b_s +
|
|
147
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
148
|
+
(tl.arange(0, 1))[None, :])
|
|
149
|
+
o_msk = (o_idx >= 0 and
|
|
150
|
+
o_idx < o_b * o_b_s)
|
|
151
|
+
tl.atomic_add(o + o_idx, buf, o_msk)
|
|
154
152
|
|
|
155
153
|
|
|
156
154
|
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
@@ -176,7 +174,7 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
176
174
|
of the input and the sparsity layout of the output tensor.
|
|
177
175
|
|
|
178
176
|
"""
|
|
179
|
-
# TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376
|
|
177
|
+
# TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
|
|
180
178
|
x = torch.where(x == -0.0, torch.tensor(0.0), x)
|
|
181
179
|
x = x.contiguous()
|
|
182
180
|
|
|
@@ -204,41 +202,42 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
204
202
|
n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
|
|
205
203
|
|
|
206
204
|
|
|
207
|
-
@triton_op("blksprs::
|
|
205
|
+
@triton_op("blksprs::row_wise_max_forward", mutates_args={})
|
|
208
206
|
def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
|
|
209
207
|
sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
|
|
210
208
|
sparsity_block_size: int, n_sparse_blocks_output: int,
|
|
211
209
|
flag_slice_only: bool = False) -> Tensor:
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
210
|
+
with torch.no_grad():
|
|
211
|
+
output = torch.full(size=(n_sparse_blocks_output,
|
|
212
|
+
sparsity_block_size,
|
|
213
|
+
1 if flag_slice_only else sparsity_block_size),
|
|
214
|
+
fill_value=torch.finfo(x.dtype).min,
|
|
215
|
+
device=x.device)
|
|
216
|
+
|
|
217
|
+
x_b, x_r, x_c = x.size()
|
|
218
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
219
|
+
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
220
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
|
|
221
|
+
o_b, o_r, o_c = output.size()
|
|
222
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
223
|
+
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
224
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
225
|
+
|
|
226
|
+
triton_grid = lambda meta: [x_b,
|
|
227
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
228
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
229
|
+
|
|
230
|
+
(wrap_triton(row_wise_max_kernel)[triton_grid]
|
|
231
|
+
(x,
|
|
232
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
233
|
+
sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
234
|
+
output,
|
|
235
|
+
o_b, o_b_s, o_r_s,
|
|
236
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
237
|
+
sparsity_reverse_lut_output,
|
|
238
|
+
sparsity_block_size))
|
|
239
|
+
|
|
240
|
+
return output
|
|
242
241
|
|
|
243
242
|
|
|
244
243
|
# noinspection PyUnusedLocal
|
|
@@ -278,25 +277,22 @@ def row_wise_max_kernel(x,
|
|
|
278
277
|
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
279
278
|
rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
280
279
|
|
|
281
|
-
if rev_idx_spa
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
blk_msk = (blk_idx >= 0 and
|
|
289
|
-
blk_idx < x_b * x_b_s)
|
|
290
|
-
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
280
|
+
if rev_idx_spa >= 0:
|
|
281
|
+
blk_idx = ((pid_blk * x_b_s) +
|
|
282
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
283
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
284
|
+
blk_msk = (blk_idx >= 0 and
|
|
285
|
+
blk_idx < x_b * x_b_s)
|
|
286
|
+
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
291
287
|
|
|
292
|
-
|
|
288
|
+
buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
293
289
|
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
290
|
+
o_idx = (rev_idx_spa * o_b_s +
|
|
291
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
292
|
+
(tl.arange(0, 1))[None, :])
|
|
293
|
+
o_msk = (o_idx >= 0 and
|
|
294
|
+
o_idx < o_b * o_b_s)
|
|
295
|
+
tl.atomic_max(o + o_idx, buf, o_msk)
|
|
300
296
|
|
|
301
297
|
|
|
302
298
|
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
@@ -343,41 +339,43 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
343
339
|
return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)
|
|
344
340
|
|
|
345
341
|
|
|
346
|
-
@triton_op("blksprs::
|
|
342
|
+
@triton_op("blksprs::row_wise_add_forward", mutates_args={})
|
|
347
343
|
def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
|
|
348
344
|
sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
|
|
349
345
|
y: Tensor, sparsity_block_size: int) -> Tensor:
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
346
|
+
with torch.no_grad():
|
|
347
|
+
output = torch.zeros_like(x)
|
|
348
|
+
|
|
349
|
+
x_b, x_r, x_c = x.size()
|
|
350
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
351
|
+
s_lut_r, s_lut_c = sparsity_lut_x.size()
|
|
352
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
|
|
353
|
+
y_b, y_r, y_c = y.size()
|
|
354
|
+
y_b_s, y_r_s, y_c_s = stride(y)
|
|
355
|
+
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
|
|
356
|
+
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
|
|
357
|
+
o_b, o_r, o_c = output.size()
|
|
358
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
359
|
+
|
|
360
|
+
triton_grid = lambda meta: [o_b,
|
|
361
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
362
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
363
|
+
|
|
364
|
+
(wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
|
|
365
|
+
(x,
|
|
366
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
367
|
+
sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
368
|
+
y, y_b, y_b_s, y_r_s, y_c_s,
|
|
369
|
+
s_l_y_b, s_l_y_b_s, s_l_y_r_s,
|
|
370
|
+
sparsity_reverse_x_lut_rwm,
|
|
371
|
+
output,
|
|
372
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
373
|
+
sparsity_block_size))
|
|
374
|
+
|
|
375
|
+
return output
|
|
379
376
|
|
|
380
377
|
|
|
378
|
+
# noinspection PyUnusedLocal
|
|
381
379
|
@triton.autotune(
|
|
382
380
|
configs=get_autotune_configs(),
|
|
383
381
|
key=["sparsity_block_size"],
|
blksprs/ops/partitioning.py
CHANGED
|
@@ -46,14 +46,15 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
46
46
|
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
@triton_op("blksprs::
|
|
49
|
+
@triton_op("blksprs::split_forward", mutates_args={})
|
|
50
50
|
def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
51
51
|
_: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
52
|
-
|
|
53
|
-
|
|
52
|
+
with torch.no_grad():
|
|
53
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
54
|
+
n_sparse_blocks)
|
|
54
55
|
|
|
55
56
|
|
|
56
|
-
def
|
|
57
|
+
def split_wrapper_backward(ctx, grad_output):
|
|
57
58
|
sparsity_layout = ctx.saved_tensors[0]
|
|
58
59
|
num_partitions = ctx.num_partitions
|
|
59
60
|
dim = ctx.dim
|
|
@@ -109,7 +110,7 @@ def split_setup_context(ctx, inputs, output):
|
|
|
109
110
|
ctx.sparsity_block_size = sparsity_block_size
|
|
110
111
|
|
|
111
112
|
|
|
112
|
-
split_forward.register_autograd(
|
|
113
|
+
split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)
|
|
113
114
|
|
|
114
115
|
|
|
115
116
|
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
@@ -150,14 +151,15 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
150
151
|
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
|
|
151
152
|
|
|
152
153
|
|
|
153
|
-
@triton_op("blksprs::
|
|
154
|
+
@triton_op("blksprs::merge_forward", mutates_args={})
|
|
154
155
|
def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
155
156
|
_: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
156
|
-
|
|
157
|
-
|
|
157
|
+
with torch.no_grad():
|
|
158
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
159
|
+
n_sparse_blocks)
|
|
158
160
|
|
|
159
161
|
|
|
160
|
-
def
|
|
162
|
+
def merge_wrapper_backward(ctx, grad_output):
|
|
161
163
|
sparsity_layout = ctx.saved_tensors[0]
|
|
162
164
|
num_partitions = ctx.num_partitions
|
|
163
165
|
dim = ctx.dim
|
|
@@ -216,4 +218,4 @@ def merge_setup_context(ctx, inputs, output):
|
|
|
216
218
|
ctx.sparsity_block_size = sparsity_block_size
|
|
217
219
|
|
|
218
220
|
|
|
219
|
-
merge_forward.register_autograd(
|
|
221
|
+
merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
|
blksprs/ops/repeat.py
CHANGED
|
@@ -92,15 +92,16 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
92
92
|
lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
@triton_op("blksprs::
|
|
95
|
+
@triton_op("blksprs::repeat_forward", mutates_args={})
|
|
96
96
|
def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
|
|
97
97
|
sparsity_reverse_lut: Tensor,
|
|
98
98
|
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
99
|
-
|
|
100
|
-
|
|
99
|
+
with torch.no_grad():
|
|
100
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
101
|
+
n_sparse_blocks)
|
|
101
102
|
|
|
102
103
|
|
|
103
|
-
def
|
|
104
|
+
def repeat_wrapper_backward(ctx, grad_output):
|
|
104
105
|
sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
|
|
105
106
|
sparsity_block_size = ctx.sparsity_block_size
|
|
106
107
|
n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
|
|
@@ -190,4 +191,4 @@ def repeat_setup_context(ctx, inputs, output):
|
|
|
190
191
|
ctx.sparsity_block_size = sparsity_block_size
|
|
191
192
|
|
|
192
193
|
|
|
193
|
-
repeat_forward.register_autograd(
|
|
194
|
+
repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)
|
blksprs/ops/softmax.py
CHANGED
|
@@ -47,19 +47,13 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
47
47
|
sparsity_block_size))
|
|
48
48
|
|
|
49
49
|
|
|
50
|
-
@triton_op("blksprs::
|
|
50
|
+
@triton_op("blksprs::softmax_forward", mutates_args={})
|
|
51
51
|
def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
52
52
|
sparsity_lut: Tensor,
|
|
53
53
|
sparsity_reverse_lut_rws: Tensor,
|
|
54
54
|
sparsity_block_size: int) -> Tensor:
|
|
55
55
|
output = torch.zeros_like(x)
|
|
56
56
|
|
|
57
|
-
x_b, x_r, x_c = x.size()
|
|
58
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
59
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
60
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
61
|
-
o_b, o_r, o_c = output.size()
|
|
62
|
-
|
|
63
57
|
x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
|
|
64
58
|
flag_slice_only=True)
|
|
65
59
|
x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
|
|
@@ -67,6 +61,11 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
|
67
61
|
x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
|
|
68
62
|
flag_slice_only=True)
|
|
69
63
|
|
|
64
|
+
x_b, x_r, x_c = x.size()
|
|
65
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
66
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
67
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
68
|
+
o_b, o_r, o_c = output.size()
|
|
70
69
|
s_b, s_r, s_c = x_exp_row_wise_sum.shape
|
|
71
70
|
s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
|
|
72
71
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
|
|
@@ -89,50 +88,58 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
|
89
88
|
return output
|
|
90
89
|
|
|
91
90
|
|
|
92
|
-
def
|
|
91
|
+
def softmax_backward_wrapper(ctx, grad_output):
|
|
93
92
|
o, sparsity_layout, sparsity_lut = ctx.saved_tensors
|
|
94
93
|
sparsity_block_size = ctx.sparsity_block_size
|
|
95
94
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
95
|
+
return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
|
|
96
|
+
sparsity_block_size), None, None, None, None, None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@triton_op("blksprs::softmax_backward", mutates_args={})
|
|
100
|
+
def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
|
|
101
|
+
sparsity_block_size: int) -> Tensor:
|
|
102
|
+
with torch.no_grad():
|
|
103
|
+
s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
|
|
104
|
+
|
|
105
|
+
sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
|
|
106
|
+
sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
|
|
107
|
+
(sparsity_layout_s_flat == 1) -
|
|
108
|
+
(1 * (sparsity_layout_s_flat == 0)))
|
|
109
|
+
|
|
110
|
+
o_b, o_r, o_c = o.size()
|
|
111
|
+
o_b_s, o_r_s, o_c_s = stride(o)
|
|
112
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
113
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
114
|
+
s_b, s_r, s_c = s.size()
|
|
115
|
+
s_b_s, s_r_s, s_c_s = stride(s)
|
|
116
|
+
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
117
|
+
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
|
|
118
|
+
|
|
119
|
+
grad_x = torch.zeros_like(o, dtype=torch.float)
|
|
120
|
+
|
|
121
|
+
triton_grid = lambda meta: [o_b,
|
|
122
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
123
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
124
|
+
|
|
125
|
+
(wrap_triton(softmax_kernel_grad)[triton_grid]
|
|
126
|
+
(grad_output,
|
|
127
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
128
|
+
o,
|
|
129
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
130
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
131
|
+
s,
|
|
132
|
+
s_b, s_b_s, s_r_s, s_c_s,
|
|
133
|
+
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
134
|
+
sparsity_reverse_lut_s,
|
|
135
|
+
grad_x,
|
|
136
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
137
|
+
sparsity_block_size))
|
|
138
|
+
|
|
139
|
+
return grad_x
|
|
134
140
|
|
|
135
141
|
|
|
142
|
+
# noinspection PyUnusedLocal
|
|
136
143
|
@triton.autotune(
|
|
137
144
|
configs=get_autotune_configs(),
|
|
138
145
|
key=["sparsity_block_size"],
|
|
@@ -193,6 +200,7 @@ def softmax_kernel(x,
|
|
|
193
200
|
tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
|
|
194
201
|
|
|
195
202
|
|
|
203
|
+
# noinspection PyUnusedLocal
|
|
196
204
|
@triton.autotune(
|
|
197
205
|
configs=get_autotune_configs(),
|
|
198
206
|
key=["sparsity_block_size"],
|
|
@@ -293,4 +301,4 @@ def softmax_setup_context(ctx, inputs, output):
|
|
|
293
301
|
ctx.sparsity_block_size = sparsity_block_size
|
|
294
302
|
|
|
295
303
|
|
|
296
|
-
softmax_forward.register_autograd(
|
|
304
|
+
softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
|
blksprs/ops/transpose.py
CHANGED
|
@@ -28,7 +28,6 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
|
28
28
|
|
|
29
29
|
"""
|
|
30
30
|
x = x.contiguous()
|
|
31
|
-
x_t = x.transpose(-1, -2).contiguous()
|
|
32
31
|
|
|
33
32
|
validate_dimensions(x)
|
|
34
33
|
validate_contiguous(x)
|
|
@@ -38,20 +37,22 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
|
38
37
|
|
|
39
38
|
lut = transpose_build_lut(lut, sparsity_layout)
|
|
40
39
|
|
|
41
|
-
return BlksprsTensor(transpose_forward(
|
|
40
|
+
return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
|
|
42
41
|
lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
43
42
|
sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
|
|
44
43
|
|
|
45
44
|
|
|
46
|
-
@triton_op("blksprs::
|
|
45
|
+
@triton_op("blksprs::transpose_forward", mutates_args={})
|
|
47
46
|
def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
|
|
48
47
|
sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
49
48
|
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
50
|
-
|
|
51
|
-
|
|
49
|
+
with torch.no_grad():
|
|
50
|
+
x_t = x.transpose(-1, -2).contiguous()
|
|
51
|
+
return flow_pull_forward(x_t, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
52
|
+
sparsity_block_size, n_sparse_blocks)
|
|
52
53
|
|
|
53
54
|
|
|
54
|
-
def
|
|
55
|
+
def transpose_wrapper_backward(ctx, grad_output):
|
|
55
56
|
sparsity_layout = ctx.saved_tensors[0]
|
|
56
57
|
sparsity_block_size = ctx.sparsity_block_size
|
|
57
58
|
|
|
@@ -96,4 +97,4 @@ def transpose_setup_context(ctx, inputs, output):
|
|
|
96
97
|
ctx.sparsity_block_size = sparsity_block_size
|
|
97
98
|
|
|
98
99
|
|
|
99
|
-
transpose_forward.register_autograd(
|
|
100
|
+
transpose_forward.register_autograd(transpose_wrapper_backward, setup_context=transpose_setup_context)
|
blksprs/utils/autotuning.py
CHANGED
|
@@ -2,15 +2,7 @@ import os
|
|
|
2
2
|
|
|
3
3
|
blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")
|
|
4
4
|
|
|
5
|
-
if blksprs_autotune_mode == "
|
|
6
|
-
autotune_parameters = [
|
|
7
|
-
(16, 3, 8),
|
|
8
|
-
|
|
9
|
-
(32, 3, 8),
|
|
10
|
-
|
|
11
|
-
(64, 3, 8),
|
|
12
|
-
]
|
|
13
|
-
elif blksprs_autotune_mode == "DEFAULT":
|
|
5
|
+
if blksprs_autotune_mode == "DEFAULT":
|
|
14
6
|
autotune_parameters = [
|
|
15
7
|
(16, 3, 8),
|
|
16
8
|
(16, 4, 4),
|
|
@@ -28,6 +20,14 @@ elif blksprs_autotune_mode == "DEFAULT":
|
|
|
28
20
|
(128, 4, 4),
|
|
29
21
|
(128, 5, 2),
|
|
30
22
|
]
|
|
23
|
+
elif blksprs_autotune_mode == "TEST":
|
|
24
|
+
autotune_parameters = [
|
|
25
|
+
(16, 3, 8),
|
|
26
|
+
|
|
27
|
+
(32, 3, 8),
|
|
28
|
+
|
|
29
|
+
(64, 3, 8),
|
|
30
|
+
]
|
|
31
31
|
else:
|
|
32
32
|
raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")
|
|
33
33
|
|
|
@@ -75,4 +75,4 @@ def get_autotune_configs():
|
|
|
75
75
|
autotune_configs.append(
|
|
76
76
|
triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
|
|
77
77
|
|
|
78
|
-
return autotune_configs
|
|
78
|
+
return autotune_configs
|
blksprs/utils/processing.py
CHANGED
|
@@ -26,7 +26,6 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
|
|
|
26
26
|
|
|
27
27
|
# Apply weights
|
|
28
28
|
sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
|
|
29
|
-
# TODO At the moment, manual cast is needed. Bug with custom_fwd?
|
|
30
29
|
xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
|
|
31
30
|
interim = xw
|
|
32
31
|
|
blksprs/utils/tools.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
import tomllib
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
1
4
|
import torch
|
|
2
5
|
from torch import Tensor, Size
|
|
3
6
|
|
|
@@ -5,6 +8,11 @@ from torch import Tensor, Size
|
|
|
5
8
|
torch._dynamo.config.capture_scalar_outputs = True
|
|
6
9
|
|
|
7
10
|
|
|
11
|
+
def version():
|
|
12
|
+
with open(Path(__file__).parent.parent.parent.joinpath("pyproject.toml"), "rb") as f:
|
|
13
|
+
return tomllib.load(f)["project"]["version"]
|
|
14
|
+
|
|
15
|
+
|
|
8
16
|
def do_shape_blocksparse(x: Tensor):
|
|
9
17
|
if x.dim() == 3:
|
|
10
18
|
return x.contiguous(), x.size()
|