blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
blksprs/ops/flow.py CHANGED
@@ -1,20 +1,65 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl

- from blksprs.utils.tools import stride, get_triton_block_size
-
-
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+
+
+ @triton_op("blksprs::flow_pull_forward", mutates_args={})
+ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
+ sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+ sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+ with torch.no_grad():
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+ dtype=x.dtype, device=x.device)
+
+ x_b, x_r, x_c = x.size()
+ x_b_s, x_r_s, x_c_s = stride(x)
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = stride(output)
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+ s_lut_r, s_lut_c = sparsity_lut.size()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+
+ triton_grid = lambda meta: [o_b,
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (wrap_triton(flow_pull_kernel)[triton_grid]
+ (x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ output,
+ o_b, o_b_s, o_r_s, o_c_s,
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ sparsity_reverse_lut,
+ sparsity_block_size))
+
+ return output
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
+ reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_blocksparse_flow_pull(x,
- x_b, x_b_s, x_r_s, x_c_s,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- r_lut,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def flow_pull_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ r_lut,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
  pid_blk = tl.program_id(axis=0)
  pid_row = tl.program_id(axis=1)
@@ -40,32 +85,72 @@ def kernel_blocksparse_flow_pull(x,
  rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

- if rev_idx_spa == -1:
- tl.device_assert(False)
- return
-
- blk_x_idx = (rev_idx_spa * x_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
- blk_o_idx = (pid_blk * o_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
- tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
+ if rev_idx_spa >= 0:
+ blk_x_idx = (rev_idx_spa * x_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx >= 0 and
+ blk_x_idx < x_b * x_b_s)
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ blk_o_idx = (pid_blk * o_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s)
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ @triton_op("blksprs::flow_push_forward", mutates_args={})
+ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+ sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+ with torch.no_grad():
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+ dtype=x.dtype, device=x.device)
+
+ x_b, x_r, x_c = x.size()
+ x_b_s, x_r_s, x_c_s = stride(x)
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+ s_lut_r, s_lut_c = sparsity_lut.size()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = stride(output)
+
+ triton_grid = lambda meta: [x_b,
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (wrap_triton(flow_push_kernel)[triton_grid]
+ (x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ sparsity_reverse_lut,
+ output,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size))
+
+ return output
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
+ reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_blocksparse_flow_push(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- r_lut,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def flow_push_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ r_lut,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
  pid_blk = tl.program_id(axis=0)
  pid_row = tl.program_id(axis=1)
@@ -91,56 +176,17 @@ def kernel_blocksparse_flow_push(x,
  rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

- if rev_idx_spa == -1:
- tl.device_assert(False)
- return
-
- blk_x_idx = (pid_blk * x_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
- blk_o_idx = (rev_idx_spa * o_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
- tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
- def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
- dtype=x.dtype, device=x.device)
-
- x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = stride(x)
- o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = stride(output)
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
- s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-
- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
-
- triton_grid = lambda meta: [o_b,
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
- (kernel_blocksparse_flow_pull[triton_grid]
- (x,
- x_b, x_b_s, x_r_s, x_c_s,
- output,
- o_b, o_b_s, o_r_s, o_c_s,
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- sparsity_reverse_lut,
- triton_block_size))
-
- # Save for backward pass
- ctx.sparsity_block_size = sparsity_block_size
- ctx.triton_block_size = triton_block_size
-
- return output
+ if rev_idx_spa >= 0:
+ blk_x_idx = (pid_blk * x_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx >= 0 and
+ blk_x_idx < x_b * x_b_s)
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ blk_o_idx = (rev_idx_spa * o_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s)
+ tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
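Both flow ops are now registered through torch.library's triton_op/wrap_triton machinery instead of a hand-rolled wrapper around raw kernel launches. The sketch below is a minimal, self-contained illustration of that registration pattern, not blksprs code: the "demo::copy" op name, the toy copy kernel, and the fixed BLOCK_SIZE are assumptions, and it presumes PyTorch >= 2.6 (which exposes triton_op and wrap_triton under torch.library), Triton, and a CUDA device.

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def copy_kernel(src, dst, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program copies one BLOCK_SIZE-wide slice of src into dst.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(src + offsets, mask=mask)
    tl.store(dst + offsets, vals, mask=mask)


@triton_op("demo::copy", mutates_args={})
def copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
    # wrap_triton makes the kernel launch traceable, so the op composes with torch.compile.
    wrap_triton(copy_kernel)[grid](x, out, x.numel(), BLOCK_SIZE=1024)
    return out


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(4096, device="cuda")
    assert torch.equal(copy(x), x)

The flow.py changes above follow the same shape, with the per-call triton_block_size argument replaced by a triton.autotune configuration keyed on sparsity_block_size.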
blksprs/ops/matmul.py CHANGED
@@ -1,19 +1,22 @@
  import torch
  import triton
  from torch import Tensor
+ from torch.library import triton_op, wrap_triton
  from triton import language as tl

  from blksprs.ops.transpose import transpose
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
- validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
+ validate_sparsity, validate_sparsity_block_size, validate_dtype_float


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
  y: BlksprsTensor, sparsity_layout_y: Tensor,
  sparsity_layout_output: Tensor,
- sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+ sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
  """Performs matrix multiplication between two block-sparse tensors.

  The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
@@ -25,7 +28,7 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
  sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
  sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
  sparsity_block_size (int): The size of the sparsity blocks.
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+ lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

  Returns:
  BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
@@ -42,44 +45,24 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
  if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
  raise ValueError("Inner dimensions of tensors must match")
  validate_sparsity_block_size(sparsity_block_size, x, y)
- validate_triton_block_size(triton_block_size, sparsity_block_size)

- sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
- sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
- (sparsity_layout_x_flat == 1) -
- (1 * (sparsity_layout_x_flat == 0)))
+ lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)

- sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
- sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
- (sparsity_layout_y_flat == 1) -
- (1 * (sparsity_layout_y_flat == 0)))
+ return BlksprsTensor(matmul_forward(x, y,
+ sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+ sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+ sparsity_layout_output, lut["sparsity_lut_o"],
+ sparsity_block_size, lut["n_sparse_blocks"]))

- sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()

- n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
- validate_contiguous(sparsity_layout_x, sparsity_reverse_lut_x,
- sparsity_layout_y, sparsity_reverse_lut_y,
- sparsity_layout_output, sparsity_lut_o)
-
- return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
- sparsity_layout_x, sparsity_reverse_lut_x,
- sparsity_layout_y, sparsity_reverse_lut_y,
- sparsity_layout_output, sparsity_lut_o,
- sparsity_block_size,
- n_sparse_blocks,
- triton_block_size))
-
-
- class _BlocksparseMatmulSSS(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, x: Tensor, y: Tensor,
- sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
- sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
- sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+ @triton_op("blksprs::matmul_forward", mutates_args={})
+ def matmul_forward(x: Tensor, y: Tensor,
+ sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+ sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
+ _: Tensor, sparsity_lut_o: Tensor,
+ sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+ with torch.no_grad():
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
  dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
@@ -95,133 +78,183 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
  s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
  s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)

- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
-
  triton_grid = lambda meta: [o_b,
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

- (_BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid]
+ (wrap_triton(matmul_kernel)[triton_grid]
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s,
+ s_l_x_c, s_l_x_c_s,
  sparsity_reverse_lut_x,
  y,
  y_b, y_b_s, y_r_s, y_c_s,
- s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+ s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+ s_l_y_c_s,
  sparsity_reverse_lut_y,
  output,
  o_b, o_b_s, o_r_s, o_c_s,
  sparsity_lut_o,
  s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
- sparsity_block_size,
- triton_block_size))
-
- ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
- ctx.sparsity_block_size = sparsity_block_size
- ctx.triton_block_size = triton_block_size
+ sparsity_block_size))

  return output

- @staticmethod
- def backward(ctx, grad_output):
- x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
- sparsity_block_size = ctx.sparsity_block_size
- triton_block_size = ctx.triton_block_size
-
- x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size, triton_block_size)
- y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size, triton_block_size)
-
- grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
- sparsity_block_size, triton_block_size)
- grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
- sparsity_block_size, triton_block_size)
-
- return grad_x, grad_y, None, None, None, None, None, None, None, None, None
-
- @staticmethod
- @triton.jit
- def kernel_blocksparse_matmul_sss(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
- r_lut_x,
- y,
- y_b, y_b_s, y_r_s, y_c_s,
- s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
- r_lut_y,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- s_lut_o,
- s_lut_o_r, s_lut_o_r_s,
- s_lut_o_c_s,
- sparsity_block_size,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get position of current sparsity block consisting of its batch, row, and column index
- spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
- spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
- spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
- # Setup buffer
- buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
-
- # Slide over triton block sized segments of input tensors
- for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
- # Convert to segment index of sparsity layout
- i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
- # Calculate the triton segment index within a block
- i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
-
- # Get reverse sparsity indices for input tensors x and y
- # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
-
- # Get reverse sparsity indices for x
- rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
- spa_row_o * s_l_x_r_s +
- i_seg_spa * s_l_x_c_s)
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
- rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
- # Get reverse sparsity indices for y
- rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
- rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
- rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
-
- # If both blocks are present commence calculation
- if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
- blk_x_idx = ((rev_idx_spa_x * x_b_s) +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
- blk_y_idx = ((rev_idx_spa_y * y_b_s) +
- ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
- blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
- blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
-
- # Perform matrix multiplication
- buf += tl.dot(blk_x, blk_y, input_precision="tf32")
-
- # Store output
- blk_o_idx = ((pid_blk * o_b_s) +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
- tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+ def matmul_wrapper_backward(ctx, grad_output):
+ x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
+ sparsity_block_size = ctx.sparsity_block_size
+
+ x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size)
+ y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size)
+
+ grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
+ sparsity_block_size)
+ grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
+ sparsity_block_size)
+
+ return grad_x, grad_y, None, None, None, None, None, None, None, None
+
+
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
+ reset_to_zero=["o"]
+ )
+ @triton.jit
+ def matmul_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+ r_lut_x,
+ y,
+ y_b, y_b_s, y_r_s, y_c_s,
+ s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+ r_lut_y,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ s_lut_o,
+ s_lut_o_r, s_lut_o_r_s,
+ s_lut_o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get position of current sparsity block consisting of its batch, row, and column index
+ spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+ spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+ spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+ spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+ spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+ spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+ # Setup buffer
+ buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
+
+ # Slide over triton block sized segments of input tensors
+ for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
+ # Convert to segment index of sparsity layout
+ i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
+ # Calculate the triton segment index within a block
+ i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
+
+ # Get reverse sparsity indices for input tensors x and y
+ # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
+
+ # Get reverse sparsity indices for x
+ rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
+ spa_row_o * s_l_x_r_s +
+ i_seg_spa * s_l_x_c_s)
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+ rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+ # Get reverse sparsity indices for y
+ rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
+ rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+ rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
+
+ # If both blocks are present commence calculation
+ if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
+ blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx >= 0 and
+ blk_x_idx < x_b * x_b_s)
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ blk_y_idx = ((rev_idx_spa_y * y_b_s) +
+ ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
+ blk_y_msk = (blk_y_idx >= 0 and
+ blk_y_idx < y_b * y_b_s)
+ blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
+
+ # Perform matrix multiplication
+ buf += tl.dot(blk_x, blk_y)
+
+ # Cast buffer
+ buf = buf.to(o.dtype.element_ty)
+
+ # Store output
+ blk_o_idx = ((pid_blk * o_b_s) +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s)
+ tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+
+ def matmul_build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
+ if lut is None:
+ lut = dict()
+
+ if "sparsity_reverse_lut_x" not in lut:
+ sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
+ sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+ (sparsity_layout_x_flat == 1) -
+ (1 * (sparsity_layout_x_flat == 0)))
+ lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+ if "sparsity_reverse_lut_y" not in lut:
+ sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
+ sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
+ (sparsity_layout_y_flat == 1) -
+ (1 * (sparsity_layout_y_flat == 0)))
+ lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
+
+ if "sparsity_lut_o" not in lut:
+ sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+ lut["sparsity_lut_o"] = sparsity_lut_o
+
+ if "n_sparse_blocks" not in lut:
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+ lut["n_sparse_blocks"] = n_sparse_blocks
+
+ validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+ sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+ sparsity_layout_output, lut["sparsity_lut_o"])
+
+ return lut
+
+
+ # noinspection PyUnusedLocal
+ def matmul_setup_context(ctx, inputs, output):
+ (x, y, sparsity_layout_x, _, sparsity_layout_y, _,
+ sparsity_layout_o, _, sparsity_block_size, _) = inputs
+
+ ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
+ ctx.sparsity_block_size = sparsity_block_size
+
+
+ matmul_forward.register_autograd(matmul_wrapper_backward, setup_context=matmul_setup_context)
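The public matmul signature replaces the per-call triton_block_size argument with an optional lut dictionary that callers can precompute and reuse. The sketch below is a hedged illustration of that call pattern, not documented blksprs usage: it only uses names visible in this diff (matmul, matmul_build_lut, BlksprsTensor), while the all-ones layouts, the random stand-in tensors, and the assumption that wrapping a plain (n_blocks, block, block) tensor in BlksprsTensor yields a valid compressed operand are all hypothetical; a CUDA device is assumed.

import torch
from blksprs.ops.matmul import matmul, matmul_build_lut
from blksprs.utils.blksprs_tensor import BlksprsTensor

sparsity_block_size = 64
# Layouts: batch of 2, a 4x4 grid of sparsity blocks, every block present.
layout_x = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")
layout_y = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")
layout_o = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")

# With all blocks present, the compressed form is simply (n_blocks, block, block).
x = BlksprsTensor(torch.randn(2 * 4 * 4, sparsity_block_size, sparsity_block_size, device="cuda"))
y = BlksprsTensor(torch.randn(2 * 4 * 4, sparsity_block_size, sparsity_block_size, device="cuda"))

# Build the look-up tables once; the same dict can be passed to every call that
# reuses these three layouts, skipping the cumsum/nonzero work on each step.
lut = matmul_build_lut(None, layout_x, layout_y, layout_o)

for _ in range(3):
    out = matmul(x, layout_x, y, layout_y, layout_o, sparsity_block_size, lut=lut)

Passing lut=None keeps the old behaviour of rebuilding the tables inside every call.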