blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,20 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl

- from blksprs.ops.conversion import to_dense
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride, get_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity, validate_dtype_int, validate_sparsity_block_size


  def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
             dim: int,
             idx: BlksprsTensor, sparsity_layout_idx: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+            sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
      """Applies a gather operation on a block-sparse tensor in compressed form.

      Args:
@@ -23,7 +24,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
          idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
          sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
          BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.
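For reference, a minimal sketch of how the reworked signature might be called; it assumes that ``src``, ``idx`` and the two sparsity layouts already exist as compressed blksprs tensors and that ``gather`` and ``gather_build_lut`` are imported from the module shown in this diff (its import path is not part of the hunk):

    # Sketch, not part of the package: inputs and import path are assumed.
    lut = gather_build_lut(None, sparsity_layout_src, sparsity_layout_idx)

    # The same lut dict can be reused across repeated calls with the same layouts,
    # so the reverse look-up table and the index LUT are only built once.
    out_1 = gather(src_1, sparsity_layout_src, 2, idx, sparsity_layout_idx,
                   sparsity_block_size, lut=lut)
    out_2 = gather(src_2, sparsity_layout_src, 2, idx, sparsity_layout_idx,
                   sparsity_block_size, lut=lut)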
@@ -38,171 +39,203 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
      validate_device(src, idx)
      validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
      validate_sparsity_block_size(sparsity_block_size, src, idx)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                               (sparsity_layout_x_flat == 1) -
-                               (1 * (sparsity_layout_x_flat == 0)))
-
-     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-
-     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
-                         sparsity_layout_idx, sparsity_lut_i)

      adjusted_dim = dim % 3

-     return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-                                                   adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
-                                                   sparsity_block_size, triton_block_size))
-
-
- class _BlocksparseGather(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                 dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
-                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-         output = torch.empty_like(i, dtype=x.dtype)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
-         i_b, i_r, i_c = i.size()
-         i_b_s, i_r_s, i_c_s = stride(i)
-         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-         s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-           sparsity_reverse_lut_x,
-           dim,
-           i,
-           i_b, i_b_s, i_r_s, i_c_s,
-           output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
-         ctx.dim = dim
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
-         dim = ctx.dim
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         return scatter_reduce(grad_output, sparsity_layout_i,
-                               dim, i,
-                               sparsity_layout_x, sparsity_block_size,
-                               reduce_op="sum",
-                               triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_gather(x,
-                                   x_b, x_b_s, x_r_s, x_c_s,
-                                   s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                                   r_lut_x,
-                                   dim,
-                                   i,
-                                   i_b, i_b_s, i_r_s, i_c_s,
-                                   o,
-                                   o_b, o_b_s, o_r_s, o_c_s,
-                                   s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                                   sparsity_block_size,
-                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch, row, and column index
-         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-         spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
-         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-         spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-         spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-         spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
-         # Load index values
-         blk_i_idx = ((pid_blk * i_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-         blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
-         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
-
-         # Get indices of sparsity blocks and positions within the blocks
-         pos_spa_blk_x = blk_i // sparsity_block_size
-         pos_spa_int_x = blk_i % sparsity_block_size
-
-         rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
-         rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
-         rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
-         dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
-                      .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-         dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
-                      .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-         if dim == 0:
-             rev_dst_bat_x = blk_i
-         elif dim == 1:
-             rev_dst_row_x = pos_spa_blk_x
-             dst_row_x = pos_spa_int_x * x_r_s
-         elif dim == 2:
-             rev_dst_col_x = pos_spa_blk_x
-             dst_col_x = pos_spa_int_x * x_c_s
-
-         # Load reverse sparsity indices for x
-         rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
-                              (rev_dst_row_x * s_l_x_r_s) +
-                              (rev_dst_col_x * s_l_x_c_s))
-         rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
-         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
-         # Load x values
-         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                      dst_row_x +
-                      dst_col_x)
-         blk_x_msk = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s) and rev_idx_spa_x_msk != -1)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-         # Store output
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_x_msk != -1)
-         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+     lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
+
+     return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                         adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                         sparsity_block_size))
+
+
+ @triton_op("blksprs::gather", mutates_args={})
+ def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                    dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
+                    sparsity_block_size: int) -> Tensor:
+     output = torch.empty_like(i, dtype=x.dtype)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+     i_b, i_r, i_c = i.size()
+     i_b_s, i_r_s, i_c_s = stride(i)
+     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+     s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(gather_kernel)[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+       sparsity_reverse_lut_x,
+       dim,
+       i,
+       i_b, i_b_s, i_r_s, i_c_s,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+       sparsity_block_size))
+
+     return output
+
+
+ def gather_backward(ctx, grad_output):
+     sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+     dim = ctx.dim
+     sparsity_block_size = ctx.sparsity_block_size
+
+     return scatter_reduce(grad_output, sparsity_layout_i,
+                           dim, i,
+                           sparsity_layout_x, sparsity_block_size,
+                           reduce_op="sum"), None, None, None, None, None, None, None
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+ )
+ @triton.jit
+ def gather_kernel(x,
+                   x_b, x_b_s, x_r_s, x_c_s,
+                   s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                   r_lut_x,
+                   dim,
+                   i,
+                   i_b, i_b_s, i_r_s, i_c_s,
+                   o,
+                   o_b, o_b_s, o_r_s, o_c_s,
+                   s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                   sparsity_block_size,
+                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+     # Get position of current sparsity block consisting of its batch, row, and column index
+     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+     spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+     spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+     # Load index values
+     blk_i_idx = ((pid_blk * i_b_s) +
+                  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                  ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+     blk_i_msk = ((blk_i_idx >= 0 and
+                   blk_i_idx < i_b * i_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+     # Get indices of sparsity blocks and positions within the blocks
+     pos_spa_blk_x = blk_i // sparsity_block_size
+     pos_spa_int_x = blk_i % sparsity_block_size
+
+     rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+     rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+     rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+     dst_row_x = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+     dst_col_x = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+     if dim == 0:
+         rev_dst_bat_x = blk_i
+     elif dim == 1:
+         rev_dst_row_x = pos_spa_blk_x
+         dst_row_x = pos_spa_int_x * x_r_s
+     elif dim == 2:
+         rev_dst_col_x = pos_spa_blk_x
+         dst_col_x = pos_spa_int_x * x_c_s
+
+     # Load reverse sparsity indices for x
+     rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                          (rev_dst_row_x * s_l_x_r_s) +
+                          (rev_dst_col_x * s_l_x_c_s))
+     rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0 and
+                           rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s) and
+                          (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                           tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+     # Load x values
+     blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                  dst_row_x +
+                  dst_col_x)
+     blk_x_msk = (((blk_x_idx >= 0 and
+                    blk_x_idx < x_b * x_b_s) and
+                   rev_idx_spa_x_msk != -1) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     # Store output
+     blk_o_idx = ((pid_blk * o_b_s) +
+                  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = (((blk_o_idx >= 0 and
+                    blk_o_idx < o_b * o_b_s) and
+                   rev_idx_spa_x_msk != -1) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def gather_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_reverse_lut_x" not in lut:
+         sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+         sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                   (sparsity_layout_x_flat == 1) -
+                                   (1 * (sparsity_layout_x_flat == 0)))
+         lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+     if "sparsity_lut_i" not in lut:
+         sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+         lut["sparsity_lut_i"] = sparsity_lut_i
+
+     validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                         sparsity_layout_idx, lut["sparsity_lut_i"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def gather_setup_context(ctx, inputs, output):
+     (_, sparsity_layout_x, _, dim, i, sparsity_layout_i, _, sparsity_block_size) = inputs
+
+     ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+     ctx.dim = dim
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ gather_forward.register_autograd(gather_backward, setup_context=gather_setup_context)


  def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
              dim: int,
              idx: BlksprsTensor,
              sparsity_layout_tgt: Tensor,
-             sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+             sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
      """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

      """
@@ -211,7 +244,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                            idx,
                            sparsity_layout_tgt,
                            sparsity_block_size,
-                           reduce_op="none", triton_block_size=triton_block_size)
+                           reduce_op="none", lut=lut)


  def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
@@ -219,7 +252,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                     idx: BlksprsTensor,
                     sparsity_layout_tgt: Tensor,
                     sparsity_block_size: int,
-                    reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
+                    reduce_op: str = "sum", lut: dict = None) -> BlksprsTensor:
      """Applies a scatter operation on a block-sparse tensor in compressed form.

      Args:
@@ -231,7 +264,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
          sparsity_block_size (int): The size of the sparsity blocks.
          reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
              Supported operations are ``"none"`` and ``"sum"``.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
          BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.
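As with ``gather``, a pre-built look-up-table dictionary can be reused across calls. A minimal sketch, assuming compressed tensors and layouts already exist and that ``scatter_reduce`` and ``scatter_reduce_build_lut`` are imported from the module shown in this diff:

    # Sketch, not part of the package: inputs and import path are assumed.
    lut = scatter_reduce_build_lut(None, sparsity_layout_src, sparsity_layout_tgt)

    # reduce_op="sum" accumulates colliding indices via tl.atomic_add in the kernel;
    # scatter() is the same call with reduce_op="none" (plain store, no accumulation).
    out = scatter_reduce(src, sparsity_layout_src, 2, idx, sparsity_layout_tgt,
                         sparsity_block_size, reduce_op="sum", lut=lut)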
@@ -246,183 +279,218 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
      validate_device(src, idx)
      validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
      validate_sparsity_block_size(sparsity_block_size, src, idx)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

      if reduce_op not in ["none", "sum"]:
          raise ValueError(f"Reduction operation '{reduce_op}' is not supported")

-     sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+     adjusted_dim = dim % 3

-     sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
-     sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
-                               (sparsity_layout_o_flat == 1) -
-                               (1 * (sparsity_layout_o_flat == 0)))
+     lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
+
+     return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                 adjusted_dim, idx,
+                                                 sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                 sparsity_block_size, lut["n_sparse_blocks"],
+                                                 reduce_op))
+
+
+ @triton_op("blksprs::scatter_reduce", mutates_args={})
+ def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
+                            dim: int, i: Tensor,
+                            sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                            sparsity_block_size: int, n_sparse_blocks: int,
+                            reduce_op: str) -> Tensor:
+     output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                          dtype=x.dtype, device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+     s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
+     i_b, i_r, i_c = i.size()
+     i_b_s, i_r_s, i_c_s = stride(i)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+
+     triton_grid = lambda meta: [x_b,
+                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+     reduce_op_ind = 0
+     if reduce_op == "sum":
+         reduce_op_ind = 1
+
+     (wrap_triton(scatter_reduce_kernel)[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+       dim,
+       i,
+       i_b, i_b_s, i_r_s, i_c_s,
+       output,
+       o_b, o_b_s,
+       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+       sparsity_reverse_lut_o,
+       reduce_op_ind,
+       sparsity_block_size))
+
+     return output
+
+
+ def scatter_reduce_backward(ctx, grad_output):
+     sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+     dim = ctx.dim
+     sparsity_block_size = ctx.sparsity_block_size
+     reduce_op = ctx.reduce_op
+
+     if reduce_op == "sum":
+         return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x,
+                       sparsity_block_size), None, None, None, None, None, None, None, None, None
+     else:
+         raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def scatter_reduce_kernel(x,
+                           x_b, x_b_s, x_r_s, x_c_s,
+                           s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                           dim,
+                           i,
+                           i_b, i_b_s, i_r_s, i_c_s,
+                           o,
+                           o_b, o_b_s,
+                           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                           r_lut_o,
+                           reduce_op_ind,
+                           sparsity_block_size,
+                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+     # Get position of current sparsity block consisting of its batch, row, and column index
+     spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+     spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+     spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+     spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+     spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+     spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
+     # Load x values
+     blk_x_idx = ((pid_blk * x_b_s) +
+                  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_x_msk = ((blk_x_idx >= 0 and
+                   blk_x_idx < x_b * x_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     # Load index values
+     blk_i_idx = ((pid_blk * i_b_s) +
+                  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                  ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+     blk_i_msk = ((blk_i_idx >= 0 and
+                   blk_i_idx < i_b * i_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+     # Get indices of sparsity blocks and positions within the blocks
+     pos_spa_blk_x = blk_i // sparsity_block_size
+     pos_spa_int_x = blk_i % sparsity_block_size
+
+     rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+     rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+     rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+     dst_row_o = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+     dst_col_o = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                  .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+     if dim == 0:
+         rev_dst_bat_o = blk_i
+     elif dim == 1:
+         rev_dst_row_o = pos_spa_blk_x
+         dst_row_o = pos_spa_int_x * x_r_s
+     elif dim == 2:
+         rev_dst_col_o = pos_spa_blk_x
+         dst_col_o = pos_spa_int_x * x_c_s
+
+     # Load reverse sparsity indices for o
+     rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                          (rev_dst_row_o * s_l_o_r_s) +
+                          (rev_dst_col_o * s_l_o_c_s))
+     rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0 and
+                           rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s) and
+                          (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                           tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+     # Store output
+     blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                  dst_row_o +
+                  dst_col_o)
+     blk_o_msk = (((blk_o_idx >= 0 and
+                    blk_o_idx < o_b * o_b_s) and
+                   rev_idx_spa_o_msk != -1) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+
+     if reduce_op_ind == 0:
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+     elif reduce_op_ind == 1:
+         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)

-     n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()

-     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
-                         sparsity_layout_tgt, sparsity_reverse_lut_o)
+ def scatter_reduce_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_lut_x" not in lut:
+         sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+         lut["sparsity_lut_x"] = sparsity_lut_x
+
+     if "sparsity_reverse_lut_o" not in lut:
+         sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+         sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                                   (sparsity_layout_o_flat == 1) -
+                                   (1 * (sparsity_layout_o_flat == 0)))
+         lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
+
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks
+
+     validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
+                         sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def scatter_reduce_setup_context(ctx, inputs, output):
+     (_, sparsity_layout_x, _, dim, i, sparsity_layout_o, _, sparsity_block_size, _, reduce_op) = inputs
+
+     ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+     ctx.dim = dim
+     ctx.sparsity_block_size = sparsity_block_size
+     ctx.reduce_op = reduce_op

-     adjusted_dim = dim % 3

-     return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
-                                                          adjusted_dim, idx,
-                                                          sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                                          sparsity_block_size, n_sparse_blocks,
-                                                          reduce_op, triton_block_size))
-
-
- class _BlocksparseScatterReduce(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                 dim: int, i: Tensor,
-                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int,
-                 reduce_op: str, triton_block_size: int) -> Tensor:
-         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                              dtype=x.dtype, device=x.device)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
-         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
-         i_b, i_r, i_c = i.size()
-         i_b_s, i_r_s, i_c_s = stride(i)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
-         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [x_b,
-                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-         reduce_op_ind = 0
-         if reduce_op == "sum":
-             reduce_op_ind = 1
-
-         (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-           dim,
-           i,
-           i_b, i_b_s, i_r_s, i_c_s,
-           output,
-           o_b, o_b_s,
-           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-           sparsity_reverse_lut_o,
-           reduce_op_ind,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
-         ctx.dim = dim
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.reduce_op = reduce_op
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
-         dim = ctx.dim
-         sparsity_block_size = ctx.sparsity_block_size
-         reduce_op = ctx.reduce_op
-         triton_block_size = ctx.triton_block_size
-
-         if reduce_op == "sum":
-             return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
-                           triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
-         else:
-             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_scatter(x,
-                                    x_b, x_b_s, x_r_s, x_c_s,
-                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                    dim,
-                                    i,
-                                    i_b, i_b_s, i_r_s, i_c_s,
-                                    o,
-                                    o_b, o_b_s,
-                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                    r_lut_o,
-                                    reduce_op_ind,
-                                    sparsity_block_size,
-                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch, row, and column index
-         spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-         spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
-         spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
-
-         spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-         spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
-         spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
-
-         spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
-         spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
-         spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
-
-         # Load x values
-         blk_x_idx = ((pid_blk * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-         # Load index values
-         blk_i_idx = ((pid_blk * i_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-         blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
-         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
-
-         # Get indices of sparsity blocks and positions within the blocks
-         pos_spa_blk_x = blk_i // sparsity_block_size
-         pos_spa_int_x = blk_i % sparsity_block_size
-
-         rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
-         rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
-         rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
-         dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
-                      .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-         dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
-                      .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-         if dim == 0:
-             rev_dst_bat_o = blk_i
-         elif dim == 1:
-             rev_dst_row_o = pos_spa_blk_x
-             dst_row_o = pos_spa_int_x * x_r_s
-         elif dim == 2:
-             rev_dst_col_o = pos_spa_blk_x
-             dst_col_o = pos_spa_int_x * x_c_s
-
-         # Load reverse sparsity indices for o
-         rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
-                              (rev_dst_row_o * s_l_o_r_s) +
-                              (rev_dst_col_o * s_l_o_c_s))
-         rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
-         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
-
-         # Store output
-         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
-                      dst_row_o +
-                      dst_col_o)
-         blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_o_msk != -1)
-
-         if reduce_op_ind == 0:
-             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-         elif reduce_op_ind == 1:
-             tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+ scatter_reduce_forward.register_autograd(scatter_reduce_backward, setup_context=scatter_reduce_setup_context)
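Taken together, the changes above drop the per-call ``triton_block_size`` argument (kernel block sizes are now selected by ``triton.autotune`` over ``get_autotune_configs()``), register the kernels as custom operators via ``triton_op``/``wrap_triton``, and let callers cache the look-up tables in a ``lut`` dictionary. A minimal before/after sketch of a call site (argument values are illustrative, not taken from the package):

    # blksprs 1.10.2
    out = gather(src, sparsity_layout_src, 2, idx, sparsity_layout_idx,
                 sparsity_block_size=64, triton_block_size=32)

    # blksprs 2.0rc1: no triton_block_size; optionally pass a reusable lut dict
    out = gather(src, sparsity_layout_src, 2, idx, sparsity_layout_idx,
                 sparsity_block_size=64, lut=lut)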