blksprs 2.0rc7__py3-none-any.whl → 2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +3 -1
- blksprs/layouting/distribution_layout.py +39 -26
- blksprs/layouting/sparsity_layout.py +58 -45
- blksprs/ops/conversion.py +88 -86
- blksprs/ops/distribution.py +80 -78
- blksprs/ops/flow.py +65 -61
- blksprs/ops/matmul.py +50 -55
- blksprs/ops/misc/broadcast_ops.py +28 -27
- blksprs/ops/misc/row_wise.py +123 -125
- blksprs/ops/partitioning.py +12 -10
- blksprs/ops/repeat.py +6 -5
- blksprs/ops/softmax.py +293 -47
- blksprs/ops/transpose.py +8 -7
- blksprs/utils/autotuning.py +10 -10
- blksprs/utils/processing.py +0 -1
- blksprs/utils/tools.py +2 -2
- {blksprs-2.0rc7.dist-info → blksprs-2.1.dist-info}/METADATA +1 -1
- blksprs-2.1.dist-info/RECORD +23 -0
- {blksprs-2.0rc7.dist-info → blksprs-2.1.dist-info}/WHEEL +1 -1
- blksprs-2.0rc7.dist-info/RECORD +0 -23
- {blksprs-2.0rc7.dist-info → blksprs-2.1.dist-info}/top_level.txt +0 -0
blksprs/ops/softmax.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import pdb
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
import triton
|
|
3
5
|
from torch import Tensor
|
|
@@ -7,6 +9,7 @@ from triton import language as tl
|
|
|
7
9
|
|
|
8
10
|
from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
9
11
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
12
|
+
from blksprs.utils.debugging import dbg_tensor_full
|
|
10
13
|
from blksprs.utils.tools import stride
|
|
11
14
|
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
12
15
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
@@ -47,19 +50,13 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
47
50
|
sparsity_block_size))
|
|
48
51
|
|
|
49
52
|
|
|
50
|
-
@triton_op("blksprs::
|
|
53
|
+
@triton_op("blksprs::softmax_forward", mutates_args={})
|
|
51
54
|
def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
52
55
|
sparsity_lut: Tensor,
|
|
53
56
|
sparsity_reverse_lut_rws: Tensor,
|
|
54
57
|
sparsity_block_size: int) -> Tensor:
|
|
55
58
|
output = torch.zeros_like(x)
|
|
56
59
|
|
|
57
|
-
x_b, x_r, x_c = x.size()
|
|
58
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
59
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
60
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
61
|
-
o_b, o_r, o_c = output.size()
|
|
62
|
-
|
|
63
60
|
x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
|
|
64
61
|
flag_slice_only=True)
|
|
65
62
|
x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
|
|
@@ -67,6 +64,11 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
|
67
64
|
x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
|
|
68
65
|
flag_slice_only=True)
|
|
69
66
|
|
|
67
|
+
x_b, x_r, x_c = x.size()
|
|
68
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
69
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
70
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
71
|
+
o_b, o_r, o_c = output.size()
|
|
70
72
|
s_b, s_r, s_c = x_exp_row_wise_sum.shape
|
|
71
73
|
s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
|
|
72
74
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
|
|
@@ -89,50 +91,58 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
|
89
91
|
return output
|
|
90
92
|
|
|
91
93
|
|
|
92
|
-
def
|
|
94
|
+
def softmax_backward_wrapper(ctx, grad_output):
|
|
93
95
|
o, sparsity_layout, sparsity_lut = ctx.saved_tensors
|
|
94
96
|
sparsity_block_size = ctx.sparsity_block_size
|
|
95
97
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
98
|
+
return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
|
|
99
|
+
sparsity_block_size), None, None, None, None, None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@triton_op("blksprs::softmax_backward", mutates_args={})
|
|
103
|
+
def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
|
|
104
|
+
sparsity_block_size: int) -> Tensor:
|
|
105
|
+
with torch.no_grad():
|
|
106
|
+
grad_x = torch.zeros_like(o, dtype=torch.float)
|
|
107
|
+
|
|
108
|
+
s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
|
|
109
|
+
|
|
110
|
+
sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
|
|
111
|
+
sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
|
|
112
|
+
(sparsity_layout_s_flat == 1) -
|
|
113
|
+
(1 * (sparsity_layout_s_flat == 0)))
|
|
114
|
+
|
|
115
|
+
o_b, o_r, o_c = o.size()
|
|
116
|
+
o_b_s, o_r_s, o_c_s = stride(o)
|
|
117
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
118
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
119
|
+
s_b, s_r, s_c = s.size()
|
|
120
|
+
s_b_s, s_r_s, s_c_s = stride(s)
|
|
121
|
+
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
122
|
+
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
|
|
123
|
+
|
|
124
|
+
triton_grid = lambda meta: [o_b,
|
|
125
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
126
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
127
|
+
|
|
128
|
+
(wrap_triton(softmax_kernel_grad)[triton_grid]
|
|
129
|
+
(grad_output,
|
|
130
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
131
|
+
o,
|
|
132
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
133
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
134
|
+
s,
|
|
135
|
+
s_b, s_b_s, s_r_s, s_c_s,
|
|
136
|
+
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
137
|
+
sparsity_reverse_lut_s,
|
|
138
|
+
grad_x,
|
|
139
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
140
|
+
sparsity_block_size))
|
|
141
|
+
|
|
142
|
+
return grad_x
|
|
134
143
|
|
|
135
144
|
|
|
145
|
+
# noinspection PyUnusedLocal
|
|
136
146
|
@triton.autotune(
|
|
137
147
|
configs=get_autotune_configs(),
|
|
138
148
|
key=["sparsity_block_size"],
|
|
@@ -193,6 +203,7 @@ def softmax_kernel(x,
|
|
|
193
203
|
tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
|
|
194
204
|
|
|
195
205
|
|
|
206
|
+
# noinspection PyUnusedLocal
|
|
196
207
|
@triton.autotune(
|
|
197
208
|
configs=get_autotune_configs(),
|
|
198
209
|
key=["sparsity_block_size"],
|
|
@@ -293,4 +304,239 @@ def softmax_setup_context(ctx, inputs, output):
|
|
|
293
304
|
ctx.sparsity_block_size = sparsity_block_size
|
|
294
305
|
|
|
295
306
|
|
|
296
|
-
softmax_forward.register_autograd(
|
|
307
|
+
softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
|
|
311
|
+
def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
312
|
+
lut: dict = None) -> BlksprsTensor:
|
|
313
|
+
"""Computes the softmax fused for each row of a block-sparse tensor in compressed form.
|
|
314
|
+
|
|
315
|
+
Note:
|
|
316
|
+
This softmax implementation is a fused version that loads the entire row of a block-sparse tensor into memory.
|
|
317
|
+
See :func:`softmax` for a true block-wise softmax implementation.
|
|
318
|
+
|
|
319
|
+
Args:
|
|
320
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
321
|
+
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
322
|
+
sparsity_block_size (int): The size of the sparsity blocks.
|
|
323
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
324
|
+
|
|
325
|
+
Returns:
|
|
326
|
+
BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
|
|
327
|
+
|
|
328
|
+
"""
|
|
329
|
+
x = x.contiguous()
|
|
330
|
+
|
|
331
|
+
validate_dimensions(x)
|
|
332
|
+
validate_contiguous(x)
|
|
333
|
+
validate_dtype_float_32(x)
|
|
334
|
+
validate_device(x)
|
|
335
|
+
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
336
|
+
validate_sparsity_block_size(sparsity_block_size, x)
|
|
337
|
+
|
|
338
|
+
lut = softmax_fused_build_lut(lut, sparsity_layout)
|
|
339
|
+
|
|
340
|
+
return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
|
|
341
|
+
lut["sparsity_reverse_lut"],
|
|
342
|
+
sparsity_block_size))
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
@triton_op("blksprs::softmax_fused_forward", mutates_args={})
|
|
346
|
+
def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
|
|
347
|
+
sparsity_reverse_lut: Tensor,
|
|
348
|
+
sparsity_block_size: int) -> Tensor:
|
|
349
|
+
output = torch.zeros_like(x)
|
|
350
|
+
|
|
351
|
+
x_b, x_r, x_c = x.size()
|
|
352
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
353
|
+
s_l_b, s_l_r, s_l_c = sparsity_layout.size()
|
|
354
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
|
|
355
|
+
|
|
356
|
+
triton_grid = lambda meta: [s_l_b,
|
|
357
|
+
s_l_r,
|
|
358
|
+
sparsity_block_size]
|
|
359
|
+
|
|
360
|
+
(wrap_triton(softmax_fused_kernel)[triton_grid]
|
|
361
|
+
(x,
|
|
362
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
363
|
+
output,
|
|
364
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
|
|
365
|
+
sparsity_reverse_lut,
|
|
366
|
+
sparsity_block_size))
|
|
367
|
+
|
|
368
|
+
return output
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def softmax_fused_backward_wrapper(ctx, grad_output):
|
|
372
|
+
o, sparsity_layout, sparsity_reverse_lut = ctx.saved_tensors
|
|
373
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
374
|
+
|
|
375
|
+
return softmax_fused_backward(grad_output, o, sparsity_reverse_lut, sparsity_layout,
|
|
376
|
+
sparsity_block_size), None, None, None, None, None
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@triton_op("blksprs::softmax_fused_backward", mutates_args={})
|
|
380
|
+
def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut: Tensor, sparsity_layout: Tensor,
|
|
381
|
+
sparsity_block_size: int) -> Tensor:
|
|
382
|
+
with torch.no_grad():
|
|
383
|
+
grad_x = torch.zeros_like(o)
|
|
384
|
+
|
|
385
|
+
g_b, g_r, g_c = grad_output.size()
|
|
386
|
+
g_b_s, g_r_s, g_c_s = stride(grad_output)
|
|
387
|
+
o_b, o_r, o_c = o.size()
|
|
388
|
+
o_b_s, o_r_s, o_c_s = stride(o)
|
|
389
|
+
s_l_b, s_l_r, s_l_c = sparsity_layout.size()
|
|
390
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
|
|
391
|
+
|
|
392
|
+
triton_grid = lambda meta: [s_l_b,
|
|
393
|
+
s_l_r,
|
|
394
|
+
sparsity_block_size]
|
|
395
|
+
|
|
396
|
+
(wrap_triton(softmax_fused_kernel_grad)[triton_grid]
|
|
397
|
+
(grad_output,
|
|
398
|
+
g_b, g_b_s, g_r_s, g_c_s,
|
|
399
|
+
o,
|
|
400
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
401
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
|
|
402
|
+
sparsity_reverse_lut,
|
|
403
|
+
grad_x,
|
|
404
|
+
sparsity_block_size))
|
|
405
|
+
|
|
406
|
+
return grad_x
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
# noinspection PyUnusedLocal
|
|
410
|
+
@triton.autotune(
|
|
411
|
+
configs=get_autotune_configs(),
|
|
412
|
+
key=["sparsity_block_size"],
|
|
413
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
414
|
+
reset_to_zero=["o"]
|
|
415
|
+
)
|
|
416
|
+
@triton.jit
|
|
417
|
+
def softmax_fused_kernel(x,
|
|
418
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
419
|
+
o,
|
|
420
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
|
|
421
|
+
r_lut,
|
|
422
|
+
sparsity_block_size: tl.constexpr,
|
|
423
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
424
|
+
# Get triton block indices
|
|
425
|
+
pid_bat = tl.program_id(axis=0)
|
|
426
|
+
pid_row = tl.program_id(axis=1)
|
|
427
|
+
pid_lin = tl.program_id(axis=2)
|
|
428
|
+
|
|
429
|
+
# Load reverse sparsity indices of row
|
|
430
|
+
blk_rev_idx = (pid_bat * s_l_b_s +
|
|
431
|
+
pid_row * s_l_r_s +
|
|
432
|
+
(tl.arange(0, s_l_c) * s_l_c_s))
|
|
433
|
+
blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
|
|
434
|
+
blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
|
|
435
|
+
|
|
436
|
+
if (not (tl.min(blk_rev) == -1 and
|
|
437
|
+
tl.max(blk_rev) == -1)):
|
|
438
|
+
# Extend sparsity indices to cover sparsity blocks
|
|
439
|
+
blk_rev_ext = tl.expand_dims(blk_rev, -1)
|
|
440
|
+
blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
|
|
441
|
+
blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
|
|
442
|
+
|
|
443
|
+
# Load line of x
|
|
444
|
+
blk_x_idx = (blk_rev_ext * x_b_s +
|
|
445
|
+
pid_lin * x_r_s +
|
|
446
|
+
(tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
|
|
447
|
+
blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
448
|
+
and blk_rev_ext != -1)
|
|
449
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
|
|
450
|
+
|
|
451
|
+
# Compute softmax
|
|
452
|
+
blk_x_softmax = tl.softmax(blk_x)
|
|
453
|
+
|
|
454
|
+
# Store output
|
|
455
|
+
tl.store(o + blk_x_idx, blk_x_softmax, mask=blk_x_mask)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
# noinspection PyUnusedLocal
|
|
459
|
+
@triton.autotune(
|
|
460
|
+
configs=get_autotune_configs(),
|
|
461
|
+
key=["sparsity_block_size"],
|
|
462
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
463
|
+
reset_to_zero=["o"]
|
|
464
|
+
)
|
|
465
|
+
@triton.jit
|
|
466
|
+
def softmax_fused_kernel_grad(g,
|
|
467
|
+
g_b, g_b_s, g_r_s, g_c_s,
|
|
468
|
+
x,
|
|
469
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
470
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
|
|
471
|
+
r_lut,
|
|
472
|
+
o,
|
|
473
|
+
sparsity_block_size: tl.constexpr,
|
|
474
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
475
|
+
# Get triton block indices
|
|
476
|
+
pid_bat = tl.program_id(axis=0)
|
|
477
|
+
pid_row = tl.program_id(axis=1)
|
|
478
|
+
pid_lin = tl.program_id(axis=2)
|
|
479
|
+
|
|
480
|
+
# Load reverse sparsity indices of row
|
|
481
|
+
blk_rev_idx = (pid_bat * s_l_b_s +
|
|
482
|
+
pid_row * s_l_r_s +
|
|
483
|
+
(tl.arange(0, s_l_c) * s_l_c_s))
|
|
484
|
+
blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
|
|
485
|
+
blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
|
|
486
|
+
|
|
487
|
+
if (not (tl.min(blk_rev) == -1 and
|
|
488
|
+
tl.max(blk_rev) == -1)):
|
|
489
|
+
# Extend sparsity indices to cover sparsity blocks
|
|
490
|
+
blk_rev_ext = tl.expand_dims(blk_rev, -1)
|
|
491
|
+
blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
|
|
492
|
+
blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
|
|
493
|
+
|
|
494
|
+
# Load line of g
|
|
495
|
+
blk_g_idx = (blk_rev_ext * g_b_s +
|
|
496
|
+
pid_lin * g_r_s +
|
|
497
|
+
(tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * g_c_s)
|
|
498
|
+
blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
|
|
499
|
+
and blk_rev_ext != -1)
|
|
500
|
+
blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
|
|
501
|
+
|
|
502
|
+
# Load line of x
|
|
503
|
+
blk_x_idx = (blk_rev_ext * x_b_s +
|
|
504
|
+
pid_lin * x_r_s +
|
|
505
|
+
(tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
|
|
506
|
+
blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
507
|
+
and blk_rev_ext != -1)
|
|
508
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
|
|
509
|
+
|
|
510
|
+
# Compute gradients
|
|
511
|
+
blk_grad = blk_x * (blk_g - tl.sum(blk_x * blk_g))
|
|
512
|
+
|
|
513
|
+
tl.store(o + blk_x_idx, blk_grad, mask=blk_x_mask)
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
|
|
517
|
+
if lut is None:
|
|
518
|
+
lut = dict()
|
|
519
|
+
|
|
520
|
+
if "sparsity_reverse_lut" not in lut:
|
|
521
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
522
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
523
|
+
(sparsity_layout_flat == 1) -
|
|
524
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
525
|
+
.reshape(sparsity_layout.size())
|
|
526
|
+
.reshape(-1).contiguous())
|
|
527
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
528
|
+
|
|
529
|
+
validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut"])
|
|
530
|
+
|
|
531
|
+
return lut
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
# noinspection PyUnusedLocal
|
|
535
|
+
def softmax_fused_setup_context(ctx, inputs, output):
|
|
536
|
+
(_, sparsity_layout, sparsity_reverse_lut, sparsity_block_size) = inputs
|
|
537
|
+
|
|
538
|
+
ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut)
|
|
539
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
softmax_fused_forward.register_autograd(softmax_fused_backward_wrapper, setup_context=softmax_fused_setup_context)
|
blksprs/ops/transpose.py
CHANGED
|
@@ -28,7 +28,6 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
|
28
28
|
|
|
29
29
|
"""
|
|
30
30
|
x = x.contiguous()
|
|
31
|
-
x_t = x.transpose(-1, -2).contiguous()
|
|
32
31
|
|
|
33
32
|
validate_dimensions(x)
|
|
34
33
|
validate_contiguous(x)
|
|
@@ -38,20 +37,22 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
|
38
37
|
|
|
39
38
|
lut = transpose_build_lut(lut, sparsity_layout)
|
|
40
39
|
|
|
41
|
-
return BlksprsTensor(transpose_forward(
|
|
40
|
+
return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
|
|
42
41
|
lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
43
42
|
sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
|
|
44
43
|
|
|
45
44
|
|
|
46
|
-
@triton_op("blksprs::
|
|
45
|
+
@triton_op("blksprs::transpose_forward", mutates_args={})
|
|
47
46
|
def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
|
|
48
47
|
sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
49
48
|
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
50
|
-
|
|
51
|
-
|
|
49
|
+
with torch.no_grad():
|
|
50
|
+
x_t = x.transpose(-1, -2).contiguous()
|
|
51
|
+
return flow_pull_forward(x_t, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
52
|
+
sparsity_block_size, n_sparse_blocks)
|
|
52
53
|
|
|
53
54
|
|
|
54
|
-
def
|
|
55
|
+
def transpose_wrapper_backward(ctx, grad_output):
|
|
55
56
|
sparsity_layout = ctx.saved_tensors[0]
|
|
56
57
|
sparsity_block_size = ctx.sparsity_block_size
|
|
57
58
|
|
|
@@ -96,4 +97,4 @@ def transpose_setup_context(ctx, inputs, output):
|
|
|
96
97
|
ctx.sparsity_block_size = sparsity_block_size
|
|
97
98
|
|
|
98
99
|
|
|
99
|
-
transpose_forward.register_autograd(
|
|
100
|
+
transpose_forward.register_autograd(transpose_wrapper_backward, setup_context=transpose_setup_context)
|
blksprs/utils/autotuning.py
CHANGED
|
@@ -2,15 +2,7 @@ import os
|
|
|
2
2
|
|
|
3
3
|
blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")
|
|
4
4
|
|
|
5
|
-
if blksprs_autotune_mode == "TEST":
|
|
6
|
-
autotune_parameters = [
|
|
7
|
-
(16, 3, 8),
|
|
8
|
-
|
|
9
|
-
(32, 3, 8),
|
|
10
|
-
|
|
11
|
-
(64, 3, 8),
|
|
12
|
-
]
|
|
13
|
-
elif blksprs_autotune_mode == "DEFAULT":
|
|
5
|
+
if blksprs_autotune_mode == "DEFAULT":
|
|
14
6
|
autotune_parameters = [
|
|
15
7
|
(16, 3, 8),
|
|
16
8
|
(16, 4, 4),
|
|
@@ -28,6 +20,14 @@ elif blksprs_autotune_mode == "DEFAULT":
|
|
|
28
20
|
(128, 4, 4),
|
|
29
21
|
(128, 5, 2),
|
|
30
22
|
]
|
|
23
|
+
elif blksprs_autotune_mode == "TEST":
|
|
24
|
+
autotune_parameters = [
|
|
25
|
+
(16, 3, 8),
|
|
26
|
+
|
|
27
|
+
(32, 3, 8),
|
|
28
|
+
|
|
29
|
+
(64, 3, 8),
|
|
30
|
+
]
|
|
31
31
|
else:
|
|
32
32
|
raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")
|
|
33
33
|
|
|
@@ -75,4 +75,4 @@ def get_autotune_configs():
|
|
|
75
75
|
autotune_configs.append(
|
|
76
76
|
triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
|
|
77
77
|
|
|
78
|
-
return autotune_configs
|
|
78
|
+
return autotune_configs
|
blksprs/utils/processing.py
CHANGED
|
@@ -26,7 +26,6 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
|
|
|
26
26
|
|
|
27
27
|
# Apply weights
|
|
28
28
|
sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
|
|
29
|
-
# TODO At the moment, manual cast is needed. Bug with custom_fwd?
|
|
30
29
|
xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
|
|
31
30
|
interim = xw
|
|
32
31
|
|
blksprs/utils/tools.py
CHANGED
|
@@ -5,14 +5,14 @@ from torch import Tensor, Size
|
|
|
5
5
|
torch._dynamo.config.capture_scalar_outputs = True
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
def do_shape_blocksparse(x: Tensor):
|
|
8
|
+
def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
|
|
9
9
|
if x.dim() == 3:
|
|
10
10
|
return x.contiguous(), x.size()
|
|
11
11
|
|
|
12
12
|
return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
def undo_shape_blocksparse(x: Tensor, shape: Size):
|
|
15
|
+
def undo_shape_blocksparse(x: Tensor, shape: Size | tuple[int, ...]) -> Tensor:
|
|
16
16
|
if x.shape[:-2] == shape[:-2]:
|
|
17
17
|
return x
|
|
18
18
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 2.0rc7
|
|
3
|
+
Version: 2.1
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
blksprs/__init__.py,sha256=o_Rj7fz_70vbMGLePihczVIVcM8E28vY3ah-d1q4ZO0,1613
|
|
2
|
+
blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
|
|
3
|
+
blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
|
|
4
|
+
blksprs/ops/conversion.py,sha256=kf5HKofZ4nVeHCIqQoYKiIlgsAhq33Tnmnr1c17Fkqs,21906
|
|
5
|
+
blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
|
|
6
|
+
blksprs/ops/flow.py,sha256=qdWBCLDSkKaa8CAfkO1NgH-J5N7yMsILyR7qEpyrIUU,8246
|
|
7
|
+
blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
|
|
8
|
+
blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
|
|
9
|
+
blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
|
|
10
|
+
blksprs/ops/softmax.py,sha256=H0OxST_XX1QLa7HDTDHznzibVHAxnp5sVbMU32HLxf0,21967
|
|
11
|
+
blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
|
|
12
|
+
blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
|
|
13
|
+
blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
|
|
14
|
+
blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2052
|
|
15
|
+
blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
|
|
16
|
+
blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
|
|
17
|
+
blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
|
|
18
|
+
blksprs/utils/tools.py,sha256=CPf7viQ2OTcZFrB1aSL8_us4VE9M6YEfDz2dE30jr9I,715
|
|
19
|
+
blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
|
|
20
|
+
blksprs-2.1.dist-info/METADATA,sha256=uPVm8Y7fX5iModz6j3hNAftdtauCsJ-iYrMa-Pv3xnU,9506
|
|
21
|
+
blksprs-2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
22
|
+
blksprs-2.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
23
|
+
blksprs-2.1.dist-info/RECORD,,
|
blksprs-2.0rc7.dist-info/RECORD
DELETED
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
|
|
2
|
-
blksprs/layouting/distribution_layout.py,sha256=TkMh_DYKX56Cb8Vq7EHyupMRvzm0XbUNP8QP7afv9wM,5122
|
|
3
|
-
blksprs/layouting/sparsity_layout.py,sha256=6GOjwllDUK9L8jEQNu2i17Pp1BIIQm8fv3xVuiR0zIw,10228
|
|
4
|
-
blksprs/ops/conversion.py,sha256=2zAdbaZ1iP2lisLVeG-k-f571G4HJapADhSwpY0Zd3o,21503
|
|
5
|
-
blksprs/ops/distribution.py,sha256=6joac_zl3ZnRkPqLPQ0d88r7IbcrWAg0HiV93LOZw-w,20453
|
|
6
|
-
blksprs/ops/flow.py,sha256=UO5ba5TFgVpEyT7r0hnWYw3vhRDpBOxyPHUBeNOAYPs,7935
|
|
7
|
-
blksprs/ops/matmul.py,sha256=02hujXMtFgF7ohepM3v6h9okrfcU-J3mQZV17B-qvh0,12235
|
|
8
|
-
blksprs/ops/partitioning.py,sha256=nAV28f3NtvT4OFvDtnE0A-VxpDQmMXS0pZw4CJwzqGA,9838
|
|
9
|
-
blksprs/ops/repeat.py,sha256=bQpJuwtt8aRdSzxT78lJ8f8fLDhPkYK5UvMfJ-PQrkc,8977
|
|
10
|
-
blksprs/ops/softmax.py,sha256=-NoTf1Cpuku9C99N0LuMydT_ObozWTnZJGDZxseXEXI,12209
|
|
11
|
-
blksprs/ops/transpose.py,sha256=PQKteFnzNAOEC7voO7wh_dq9c54UjCboJz889aBCwKc,4010
|
|
12
|
-
blksprs/ops/misc/broadcast_ops.py,sha256=DhUbliT9TBT6zlEjutBmY1EAEUPmYOt2mKQ5i46vN1c,5880
|
|
13
|
-
blksprs/ops/misc/row_wise.py,sha256=5u_J8WOTepvf6XtZ8r0lLPofYrI5fGB7mxSmGC81IR0,19167
|
|
14
|
-
blksprs/utils/autotuning.py,sha256=tDfMWklm2rvbo0-ahH81C3Gg0U6LHjPn3d_3pEOzmJs,2053
|
|
15
|
-
blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
|
|
16
|
-
blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
|
|
17
|
-
blksprs/utils/processing.py,sha256=xuu9iDpwTvsqI_WKMSD8QCNuvPnfcKMRcuF2L4Zs6Ts,3808
|
|
18
|
-
blksprs/utils/tools.py,sha256=3_2IBbd54vVU4-6m2KtAN7qjU6jeF4UfPkbjeFqMpYo,664
|
|
19
|
-
blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
|
|
20
|
-
blksprs-2.0rc7.dist-info/METADATA,sha256=ER9DHdVeYUZUsjE-2bEB9fePw0FVI1vknwPNrj7mDPE,9509
|
|
21
|
-
blksprs-2.0rc7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
22
|
-
blksprs-2.0rc7.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
23
|
-
blksprs-2.0rc7.dist-info/RECORD,,
|
|
File without changes
|