blksprs 1.11__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,22 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.ops.conversion import to_dense
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
+           sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.
 
     Args:
@@ -23,7 +26,6 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
@@ -39,45 +41,22 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     adjusted_dim = dim % 3
 
-    lut = _BlocksparseGather.build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
+    lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
 
-    return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
-                                                  adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
-                                                  sparsity_block_size, triton_block_size))
+    return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                        adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                        sparsity_block_size))
 
 
-class _BlocksparseGather(torch.autograd.Function):
-
-    @staticmethod
-    def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
-        if lut is None:
-            lut = dict()
-
-        if "sparsity_reverse_lut_x" not in lut:
-            sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-            sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                                      (sparsity_layout_x_flat == 1) -
-                                      (1 * (sparsity_layout_x_flat == 0)))
-            lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
-
-        if "sparsity_lut_i" not in lut:
-            sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-            lut["sparsity_lut_i"] = sparsity_lut_i
-
-        validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
-                            sparsity_layout_idx, lut["sparsity_lut_i"])
-
-        return lut
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
-                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-        output = torch.empty_like(i, dtype=x.dtype)
+@triton_op("blksprs::gather_forward", mutates_args={})
+def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                   dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
+                   sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros_like(i, dtype=x.dtype)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = stride(x)
@@ -90,14 +69,11 @@ class _BlocksparseGather(torch.autograd.Function):
         o_b, o_r, o_c = output.size()
         o_b_s, o_r_s, o_c_s = stride(output)
 
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-        (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]
+        (wrap_triton(gather_kernel)[triton_grid]
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -108,115 +84,152 @@ class _BlocksparseGather(torch.autograd.Function):
           output,
           o_b, o_b_s, o_r_s, o_c_s,
           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-          sparsity_block_size,
-          triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
+          sparsity_block_size))
 
         return output
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        return scatter_reduce(grad_output, sparsity_layout_i,
-                              dim, i,
-                              sparsity_layout_x, sparsity_block_size,
-                              reduce_op="sum",
-                              triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_gather(x,
-                                  x_b, x_b_s, x_r_s, x_c_s,
-                                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                                  r_lut_x,
-                                  dim,
-                                  i,
-                                  i_b, i_b_s, i_r_s, i_c_s,
-                                  o,
-                                  o_b, o_b_s, o_r_s, o_c_s,
-                                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                                  sparsity_block_size,
-                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch, row, and column index
-        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-        spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
-        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-        spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-        spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
-        # Load index values
-        blk_i_idx = ((pid_blk * i_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-        blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
-        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
-
-        # Get indices of sparsity blocks and positions within the blocks
-        pos_spa_blk_x = blk_i // sparsity_block_size
-        pos_spa_int_x = blk_i % sparsity_block_size
-
-        rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
-        rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
-        rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
-        dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        if dim == 0:
-            rev_dst_bat_x = blk_i
-        elif dim == 1:
-            rev_dst_row_x = pos_spa_blk_x
-            dst_row_x = pos_spa_int_x * x_r_s
-        elif dim == 2:
-            rev_dst_col_x = pos_spa_blk_x
-            dst_col_x = pos_spa_int_x * x_c_s
-
-        # Load reverse sparsity indices for x
-        rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
-                             (rev_dst_row_x * s_l_x_r_s) +
-                             (rev_dst_col_x * s_l_x_c_s))
-        rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
-        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
-        # Load x values
-        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                     dst_row_x +
-                     dst_col_x)
-        blk_x_msk = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s) and rev_idx_spa_x_msk != -1)
-        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-        # Store output
-        blk_o_idx = ((pid_blk * o_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_x_msk != -1)
-        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+def gather_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return scatter_reduce(grad_output, sparsity_layout_i,
+                          dim, i,
+                          sparsity_layout_x, sparsity_block_size,
+                          reduce_op="sum"), None, None, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def gather_kernel(x,
+                  x_b, x_b_s, x_r_s, x_c_s,
+                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                  r_lut_x,
+                  dim,
+                  i,
+                  i_b, i_b_s, i_r_s, i_c_s,
+                  o,
+                  o_b, o_b_s, o_r_s, o_c_s,
+                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                  sparsity_block_size,
+                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+    rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+    rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+    dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_x = blk_i
+    elif dim == 1:
+        rev_dst_row_x = pos_spa_blk_x
+        dst_row_x = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_x = pos_spa_blk_x
+        dst_col_x = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for x
+    rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                         (rev_dst_row_x * s_l_x_r_s) +
+                         (rev_dst_col_x * s_l_x_c_s))
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
+                         rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+    # Load x values
+    blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                 dst_row_x +
+                 dst_col_x)
+    blk_x_msk = ((blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s) and
+                 rev_idx_spa_x_msk != -1)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Store output
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 rev_idx_spa_x_msk != -1)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def gather_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_reverse_lut_x" not in lut:
+        sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+        sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                  (sparsity_layout_x_flat == 1) -
+                                  (1 * (sparsity_layout_x_flat == 0)))
+        lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+    if "sparsity_lut_i" not in lut:
+        sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+        lut["sparsity_lut_i"] = sparsity_lut_i
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                        sparsity_layout_idx, lut["sparsity_lut_i"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def gather_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_i, _, sparsity_block_size) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+gather_forward.register_autograd(gather_wrapper_backward, setup_context=gather_setup_context)
 
 
 def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
             dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
+            sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
 
     """
@@ -225,15 +238,16 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                           idx,
                           sparsity_layout_tgt,
                           sparsity_block_size,
-                          reduce_op="none", triton_block_size=triton_block_size)
+                          reduce_op="none", lut=lut)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    dim: int,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum", triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
+                   reduce_op: str = "sum", lut: dict = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.
 
     Args:
@@ -245,7 +259,6 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
         sparsity_block_size (int): The size of the sparsity blocks.
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
             Supported operations are ``"none"`` and ``"sum"``.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
@@ -261,55 +274,28 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     if reduce_op not in ["none", "sum"]:
         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
 
     adjusted_dim = dim % 3
 
-    lut = _BlocksparseScatterReduce.build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
-
-    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, lut["sparsity_lut_x"],
-                                                         adjusted_dim, idx,
-                                                         sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
-                                                         sparsity_block_size, lut["n_sparse_blocks"],
-                                                         reduce_op, triton_block_size))
-
-
-class _BlocksparseScatterReduce(torch.autograd.Function):
-
-    @staticmethod
-    def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
-        if lut is None:
-            lut = dict()
+    lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
 
-        if "sparsity_lut_x" not in lut:
-            sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
-            lut["sparsity_lut_x"] = sparsity_lut_x
+    return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                adjusted_dim, idx,
+                                                sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                sparsity_block_size, lut["n_sparse_blocks"],
+                                                reduce_op))
 
-        if "sparsity_reverse_lut_o" not in lut:
-            sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
-            sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
-                                      (sparsity_layout_o_flat == 1) -
-                                      (1 * (sparsity_layout_o_flat == 0)))
-            lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
 
-        if "n_sparse_blocks" not in lut:
-            n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
-            lut["n_sparse_blocks"] = n_sparse_blocks
-
-        validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
-                            sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
-
-        return lut
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                dim: int, i: Tensor,
-                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
-                sparsity_block_size: int, n_sparse_blocks: int,
-                reduce_op: str, triton_block_size: int) -> Tensor:
+@triton_op("blksprs::scatter_reduce_forward", mutates_args={})
+def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
+                           dim: int, i: Tensor,
+                           sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                           sparsity_block_size: int, n_sparse_blocks: int,
+                           reduce_op: str) -> Tensor:
+    with torch.no_grad():
         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
@@ -324,9 +310,6 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
 
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [x_b,
                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
@@ -335,7 +318,7 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         if reduce_op == "sum":
             reduce_op_ind = 1
 
-        (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
+        (wrap_triton(scatter_reduce_kernel)[triton_grid]
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
@@ -347,112 +330,153 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
           sparsity_reverse_lut_o,
           reduce_op_ind,
-          sparsity_block_size,
-          triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.reduce_op = reduce_op
-        ctx.triton_block_size = triton_block_size
+          sparsity_block_size))
 
         return output
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        reduce_op = ctx.reduce_op
-        triton_block_size = ctx.triton_block_size
 
-        if reduce_op == "sum":
-            return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
-                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
-        else:
-            raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_scatter(x,
-                                   x_b, x_b_s, x_r_s, x_c_s,
-                                   s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                   dim,
-                                   i,
-                                   i_b, i_b_s, i_r_s, i_c_s,
-                                   o,
-                                   o_b, o_b_s,
-                                   s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                   r_lut_o,
-                                   reduce_op_ind,
-                                   sparsity_block_size,
-                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch, row, and column index
-        spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-        spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
-
-        spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-        spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
-
-        spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
-        spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
-
-        # Load x values
-        blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-        # Load index values
-        blk_i_idx = ((pid_blk * i_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-        blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
-        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
-
-        # Get indices of sparsity blocks and positions within the blocks
-        pos_spa_blk_x = blk_i // sparsity_block_size
-        pos_spa_int_x = blk_i % sparsity_block_size
-
-        rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
-        rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
-        rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
-        dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        if dim == 0:
-            rev_dst_bat_o = blk_i
-        elif dim == 1:
-            rev_dst_row_o = pos_spa_blk_x
-            dst_row_o = pos_spa_int_x * x_r_s
-        elif dim == 2:
-            rev_dst_col_o = pos_spa_blk_x
-            dst_col_o = pos_spa_int_x * x_c_s
-
-        # Load reverse sparsity indices for o
-        rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
-                             (rev_dst_row_o * s_l_o_r_s) +
-                             (rev_dst_col_o * s_l_o_c_s))
-        rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
-        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
-
-        # Store output
-        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
-                     dst_row_o +
-                     dst_col_o)
-        blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_o_msk != -1)
-
-        if reduce_op_ind == 0:
-            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-        elif reduce_op_ind == 1:
-            tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+def scatter_reduce_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+    reduce_op = ctx.reduce_op
+
+    if reduce_op == "sum":
+        return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x,
+                      sparsity_block_size), None, None, None, None, None, None, None, None, None
+    else:
+        raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def scatter_reduce_kernel(x,
+                          x_b, x_b_s, x_r_s, x_c_s,
+                          s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                          dim,
+                          i,
+                          i_b, i_b_s, i_r_s, i_c_s,
+                          o,
+                          o_b, o_b_s,
+                          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                          r_lut_o,
+                          reduce_op_ind,
+                          sparsity_block_size,
+                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+    spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+    spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+    spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
+    # Load x values
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+    rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+    rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+    dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_o = blk_i
+    elif dim == 1:
+        rev_dst_row_o = pos_spa_blk_x
+        dst_row_o = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_o = pos_spa_blk_x
+        dst_col_o = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for o
+    rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                         (rev_dst_row_o * s_l_o_r_s) +
+                         (rev_dst_col_o * s_l_o_c_s))
+    rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
+                         rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+    # Store output
+    blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                 dst_row_o +
+                 dst_col_o)
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 rev_idx_spa_o_msk != -1)
+
+    if reduce_op_ind == 0:
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+    elif reduce_op_ind == 1:
+        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def scatter_reduce_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut_x" not in lut:
+        sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+        lut["sparsity_lut_x"] = sparsity_lut_x
+
+    if "sparsity_reverse_lut_o" not in lut:
+        sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+        sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                                  (sparsity_layout_o_flat == 1) -
+                                  (1 * (sparsity_layout_o_flat == 0)))
+        lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
+                        sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def scatter_reduce_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_o, _, sparsity_block_size, _, reduce_op) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.reduce_op = reduce_op
+
+
+scatter_reduce_forward.register_autograd(scatter_reduce_wrapper_backward, setup_context=scatter_reduce_setup_context)
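
Taken together, blksprs 2.0 drops the triton_block_size parameter from gather, scatter, and scatter_reduce: the kernel block size is now chosen by triton.autotune (configs from get_autotune_configs(), keyed on sparsity_block_size), and autograd is wired through triton_op/register_autograd instead of torch.autograd.Function subclasses. The new torch.amp.custom_fwd decorators also cast inputs to float16 (gather) and float32 (scatter_reduce) under CUDA autocast. Below is a minimal call-site sketch, assuming src, idx, and the sparsity layout tensors already exist in blksprs compressed form and that gather and scatter_reduce are imported from this module; neither the tensor construction nor the import path appears in this diff.

    # Hypothetical call sites; the tensors and the block size 64 are assumptions, not taken from the diff.

    # blksprs 1.11 required an explicit Triton block size:
    # out = gather(src, sparsity_layout_src, 2, idx, sparsity_layout_idx,
    #              sparsity_block_size=64, triton_block_size=32)

    # blksprs 2.0: the argument is gone; TRITON_BLOCK_SIZE is autotuned per sparsity_block_size.
    out = gather(src, sparsity_layout_src, 2, idx, sparsity_layout_idx,
                 sparsity_block_size=64)

    # scatter_reduce changes the same way; reduce_op still defaults to "sum".
    res = scatter_reduce(out, sparsity_layout_idx, 2, idx, sparsity_layout_src,
                         sparsity_block_size=64, reduce_op="sum")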