blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,22 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl

-from blksprs.ops.conversion import to_dense
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+           sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.

     Args:
@@ -23,7 +26,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.
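
Note on the new parameter: the old `triton_block_size` knob is gone (kernel block sizes are now autotuned, see the `@triton.autotune` decorators further down), and the optional `lut` lets callers prebuild the look-up tables once via `gather_build_lut` (added later in this diff) and reuse them across calls. A hypothetical usage sketch; the import path, tensor shapes, and dtypes below are illustrative assumptions, not taken from the diff:

    import torch
    from blksprs.utils.blksprs_tensor import BlksprsTensor
    from blksprs.ops.gather_scatter import gather, gather_build_lut  # module path assumed

    sparsity_block_size = 16
    # Layouts mark which (batch, row, col) blocks are materialised; here one batch with a 2x2 grid of blocks.
    sparsity_layout_src = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")
    sparsity_layout_idx = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")

    # Compressed form: one (sparsity_block_size x sparsity_block_size) tile per non-zero layout entry.
    src = BlksprsTensor(torch.randn(4, sparsity_block_size, sparsity_block_size, device="cuda"))
    idx = BlksprsTensor(torch.randint(0, 2 * sparsity_block_size,
                                      (4, sparsity_block_size, sparsity_block_size), device="cuda"))

    lut = gather_build_lut(None, sparsity_layout_src, sparsity_layout_idx)  # build the LUTs once
    out = gather(src, sparsity_layout_src, 2, idx, sparsity_layout_idx,     # gather along the column dim
                 sparsity_block_size, lut=lut)                              # reuse the LUTs on every call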
@@ -38,32 +41,22 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                              (sparsity_layout_x_flat == 1) -
-                              (1 * (sparsity_layout_x_flat == 0)))
-
-    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-
-    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
-                        sparsity_layout_idx, sparsity_lut_i)

     adjusted_dim = dim % 3

-    return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-                                                  adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
-                                                  sparsity_block_size, triton_block_size))
+    lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)

+    return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                        adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                        sparsity_block_size))

-class _BlocksparseGather(torch.autograd.Function):

-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
-                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-        output = torch.empty_like(i, dtype=x.dtype)
+@triton_op("blksprs::gather_forward", mutates_args={})
+def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                   dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
+                   sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros_like(i, dtype=x.dtype)

         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = stride(x)
@@ -76,14 +69,11 @@ class _BlocksparseGather(torch.autograd.Function):
         o_b, o_r, o_c = output.size()
         o_b_s, o_r_s, o_c_s = stride(output)

-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-        (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]
+        (wrap_triton(gather_kernel)[triton_grid]
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -94,115 +84,152 @@ class _BlocksparseGather(torch.autograd.Function):
           output,
           o_b, o_b_s, o_r_s, o_c_s,
           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-          sparsity_block_size,
-          triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
+          sparsity_block_size))

         return output

-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        return scatter_reduce(grad_output, sparsity_layout_i,
-                              dim, i,
-                              sparsity_layout_x, sparsity_block_size,
-                              reduce_op="sum",
-                              triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_gather(x,
-                                  x_b, x_b_s, x_r_s, x_c_s,
-                                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                                  r_lut_x,
-                                  dim,
-                                  i,
-                                  i_b, i_b_s, i_r_s, i_c_s,
-                                  o,
-                                  o_b, o_b_s, o_r_s, o_c_s,
-                                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                                  sparsity_block_size,
-                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch, row, and column index
-        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-        spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
-        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-        spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-        spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
-        # Load index values
-        blk_i_idx = ((pid_blk * i_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-        blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
-        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
-
-        # Get indices of sparsity blocks and positions within the blocks
-        pos_spa_blk_x = blk_i // sparsity_block_size
-        pos_spa_int_x = blk_i % sparsity_block_size
-
-        rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
-        rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
-        rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
-        dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        if dim == 0:
-            rev_dst_bat_x = blk_i
-        elif dim == 1:
-            rev_dst_row_x = pos_spa_blk_x
-            dst_row_x = pos_spa_int_x * x_r_s
-        elif dim == 2:
-            rev_dst_col_x = pos_spa_blk_x
-            dst_col_x = pos_spa_int_x * x_c_s
-
-        # Load reverse sparsity indices for x
-        rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
-                             (rev_dst_row_x * s_l_x_r_s) +
-                             (rev_dst_col_x * s_l_x_c_s))
-        rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
-        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
-        # Load x values
-        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                     dst_row_x +
-                     dst_col_x)
-        blk_x_msk = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s) and rev_idx_spa_x_msk != -1)
-        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-        # Store output
-        blk_o_idx = ((pid_blk * o_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_x_msk != -1)
-        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+def gather_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return scatter_reduce(grad_output, sparsity_layout_i,
+                          dim, i,
+                          sparsity_layout_x, sparsity_block_size,
+                          reduce_op="sum"), None, None, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def gather_kernel(x,
+                  x_b, x_b_s, x_r_s, x_c_s,
+                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                  r_lut_x,
+                  dim,
+                  i,
+                  i_b, i_b_s, i_r_s, i_c_s,
+                  o,
+                  o_b, o_b_s, o_r_s, o_c_s,
+                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                  sparsity_block_size,
+                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+    rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+    rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+    dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_x = blk_i
+    elif dim == 1:
+        rev_dst_row_x = pos_spa_blk_x
+        dst_row_x = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_x = pos_spa_blk_x
+        dst_col_x = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for x
+    rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                         (rev_dst_row_x * s_l_x_r_s) +
+                         (rev_dst_col_x * s_l_x_c_s))
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
+                         rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+    # Load x values
+    blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                 dst_row_x +
+                 dst_col_x)
+    blk_x_msk = ((blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s) and
+                 rev_idx_spa_x_msk != -1)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Store output
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 rev_idx_spa_x_msk != -1)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def gather_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_reverse_lut_x" not in lut:
+        sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+        sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                  (sparsity_layout_x_flat == 1) -
+                                  (1 * (sparsity_layout_x_flat == 0)))
+        lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+    if "sparsity_lut_i" not in lut:
+        sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+        lut["sparsity_lut_i"] = sparsity_lut_i
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                        sparsity_layout_idx, lut["sparsity_lut_i"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def gather_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_i, _, sparsity_block_size) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+gather_forward.register_autograd(gather_wrapper_backward, setup_context=gather_setup_context)


 def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
             dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+            sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

     """
@@ -211,15 +238,16 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                           idx,
                           sparsity_layout_tgt,
                           sparsity_block_size,
-                          reduce_op="none", triton_block_size=triton_block_size)
+                          reduce_op="none", lut=lut)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    dim: int,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
+                   reduce_op: str = "sum", lut: dict = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.

     Args:
@@ -231,7 +259,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
         sparsity_block_size (int): The size of the sparsity blocks.
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
             Supported operations are ``"none"`` and ``"sum"``.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.
@@ -246,40 +274,28 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

     if reduce_op not in ["none", "sum"]:
         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")

-    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
-
-    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
-    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
-                              (sparsity_layout_o_flat == 1) -
-                              (1 * (sparsity_layout_o_flat == 0)))
-
-    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
-                        sparsity_layout_tgt, sparsity_reverse_lut_o)
-
     adjusted_dim = dim % 3

-    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
-                                                         adjusted_dim, idx,
-                                                         sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                                         sparsity_block_size, n_sparse_blocks,
-                                                         reduce_op, triton_block_size))
+    lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)

+    return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                adjusted_dim, idx,
+                                                sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                sparsity_block_size, lut["n_sparse_blocks"],
+                                                reduce_op))

-class _BlocksparseScatterReduce(torch.autograd.Function):

-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                dim: int, i: Tensor,
-                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
-                sparsity_block_size: int, n_sparse_blocks: int,
-                reduce_op: str, triton_block_size: int) -> Tensor:
+@triton_op("blksprs::scatter_reduce_forward", mutates_args={})
+def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
+                           dim: int, i: Tensor,
+                           sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                           sparsity_block_size: int, n_sparse_blocks: int,
+                           reduce_op: str) -> Tensor:
+    with torch.no_grad():
         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)

@@ -294,9 +310,6 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)

-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [x_b,
                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
@@ -305,7 +318,7 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         if reduce_op == "sum":
             reduce_op_ind = 1

-        (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
+        (wrap_triton(scatter_reduce_kernel)[triton_grid]
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
@@ -317,112 +330,153 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
           sparsity_reverse_lut_o,
           reduce_op_ind,
-          sparsity_block_size,
-          triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.reduce_op = reduce_op
-        ctx.triton_block_size = triton_block_size
+          sparsity_block_size))

         return output

-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        reduce_op = ctx.reduce_op
-        triton_block_size = ctx.triton_block_size

-        if reduce_op == "sum":
-            return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
-                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
-        else:
-            raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_scatter(x,
-                                   x_b, x_b_s, x_r_s, x_c_s,
-                                   s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                   dim,
-                                   i,
-                                   i_b, i_b_s, i_r_s, i_c_s,
-                                   o,
-                                   o_b, o_b_s,
-                                   s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                   r_lut_o,
-                                   reduce_op_ind,
-                                   sparsity_block_size,
-                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch, row, and column index
-        spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-        spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
-
-        spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-        spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
-
-        spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
-        spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
-
-        # Load x values
-        blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-        # Load index values
-        blk_i_idx = ((pid_blk * i_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-        blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
-        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
-
-        # Get indices of sparsity blocks and positions within the blocks
-        pos_spa_blk_x = blk_i // sparsity_block_size
-        pos_spa_int_x = blk_i % sparsity_block_size
-
-        rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
-        rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
-        rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
-        dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        if dim == 0:
-            rev_dst_bat_o = blk_i
-        elif dim == 1:
-            rev_dst_row_o = pos_spa_blk_x
-            dst_row_o = pos_spa_int_x * x_r_s
-        elif dim == 2:
-            rev_dst_col_o = pos_spa_blk_x
-            dst_col_o = pos_spa_int_x * x_c_s
-
-        # Load reverse sparsity indices for o
-        rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
-                             (rev_dst_row_o * s_l_o_r_s) +
-                             (rev_dst_col_o * s_l_o_c_s))
-        rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
-        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
-
-        # Store output
-        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
-                     dst_row_o +
-                     dst_col_o)
-        blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_o_msk != -1)
-
-        if reduce_op_ind == 0:
-            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-        elif reduce_op_ind == 1:
-            tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+def scatter_reduce_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+    reduce_op = ctx.reduce_op
+
+    if reduce_op == "sum":
+        return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x,
+                      sparsity_block_size), None, None, None, None, None, None, None, None, None
+    else:
+        raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def scatter_reduce_kernel(x,
+                          x_b, x_b_s, x_r_s, x_c_s,
+                          s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                          dim,
+                          i,
+                          i_b, i_b_s, i_r_s, i_c_s,
+                          o,
+                          o_b, o_b_s,
+                          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                          r_lut_o,
+                          reduce_op_ind,
+                          sparsity_block_size,
+                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+    spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+    spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+    spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
+    # Load x values
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+    rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+    rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+    dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_o = blk_i
+    elif dim == 1:
+        rev_dst_row_o = pos_spa_blk_x
+        dst_row_o = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_o = pos_spa_blk_x
+        dst_col_o = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for o
+    rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                         (rev_dst_row_o * s_l_o_r_s) +
+                         (rev_dst_col_o * s_l_o_c_s))
+    rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
+                         rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+    # Store output
+    blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                 dst_row_o +
+                 dst_col_o)
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 rev_idx_spa_o_msk != -1)
+
+    if reduce_op_ind == 0:
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+    elif reduce_op_ind == 1:
+        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def scatter_reduce_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut_x" not in lut:
+        sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+        lut["sparsity_lut_x"] = sparsity_lut_x
+
+    if "sparsity_reverse_lut_o" not in lut:
+        sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+        sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                                  (sparsity_layout_o_flat == 1) -
+                                  (1 * (sparsity_layout_o_flat == 0)))
+        lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
+                        sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def scatter_reduce_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_o, _, sparsity_block_size, _, reduce_op) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.reduce_op = reduce_op
+
+
+scatter_reduce_forward.register_autograd(scatter_reduce_wrapper_backward, setup_context=scatter_reduce_setup_context)
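
Both kernels above also drop the explicit `triton_block_size` argument: `TRITON_BLOCK_SIZE` is now chosen by `triton.autotune`, keyed on `sparsity_block_size` and fed by `get_autotune_configs()` / `prune_autotune_configs` from `blksprs.utils.autotuning` (not shown in this diff). A standalone sketch of that decorator pattern with made-up configs, to show how the block size becomes a tuned `tl.constexpr` instead of a user-supplied parameter:

    import torch
    import triton
    from triton import language as tl


    def example_autotune_configs():
        # Illustrative candidates only; blksprs' real configs come from get_autotune_configs().
        return [triton.Config({"TRITON_BLOCK_SIZE": bs}, num_warps=w)
                for bs in (16, 32, 64) for w in (2, 4)]


    @triton.autotune(configs=example_autotune_configs(), key=["sparsity_block_size"])
    @triton.jit
    def copy_kernel(x, o, n, sparsity_block_size, TRITON_BLOCK_SIZE: tl.constexpr):
        offs = tl.program_id(axis=0) * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
        msk = offs < n
        tl.store(o + offs, tl.load(x + offs, mask=msk), mask=msk)


    x = torch.randn(1024, device="cuda")
    o = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["TRITON_BLOCK_SIZE"]),)
    copy_kernel[grid](x, o, x.numel(), 64)  # the autotuner picks TRITON_BLOCK_SIZE for this key value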