blksprs-1.11-py3-none-any.whl → blksprs-2.0rc1-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
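In short, the hunks below migrate the gather and scatter/scatter_reduce operations from torch.autograd.Function subclasses to module-level functions registered through PyTorch's custom-operator machinery (triton_op plus wrap_triton), drop the manual triton_block_size argument in favor of @triton.autotune over configs from blksprs.utils.tools.get_autotune_configs, and clamp the effective tile inside the kernels with val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE). A minimal, hedged sketch of the registration pattern (example::copy, copy_kernel and BLOCK are illustrative names, not blksprs API; the public torch.library path is assumed to match the private torch._library imports used in the diff, available since PyTorch 2.6):

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton  # public counterpart of the torch._library imports below

@triton.jit
def copy_kernel(x_ptr, o_ptr, n, BLOCK: tl.constexpr):
    # Blocked copy: each program instance handles BLOCK contiguous elements.
    offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    msk = offs < n
    tl.store(o_ptr + offs, tl.load(x_ptr + offs, mask=msk), mask=msk)

@triton_op("example::copy", mutates_args={})
def copy(x: torch.Tensor) -> torch.Tensor:
    # The op allocates its own output and launches the kernel through wrap_triton,
    # which keeps the launch traceable by torch.compile.
    o = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
    wrap_triton(copy_kernel)[grid](x, o, x.numel(), BLOCK=1024)
    return o

gather_forward and scatter_reduce_forward below follow the same shape, with the autograd wiring attached at the end of each block.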
@@ -1,19 +1,20 @@
1
1
  import torch
2
2
  import triton
3
3
  from torch import Tensor
4
+ from torch._library import triton_op
5
+ from torch._library.triton import wrap_triton
4
6
  from triton import language as tl
5
7
 
6
- from blksprs.ops.conversion import to_dense
7
8
  from blksprs.utils.blksprs_tensor import BlksprsTensor
8
- from blksprs.utils.tools import get_triton_block_size, stride
9
+ from blksprs.utils.tools import stride, get_autotune_configs
9
10
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
10
- validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
11
+ validate_sparsity, validate_dtype_int, validate_sparsity_block_size
11
12
 
12
13
 
13
14
  def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
14
15
  dim: int,
15
16
  idx: BlksprsTensor, sparsity_layout_idx: Tensor,
16
- sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
17
+ sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
17
18
  """Applies a gather operation on a block-sparse tensor in compressed form.
18
19
 
19
20
  Args:
@@ -23,7 +24,6 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
23
24
  idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
24
25
  sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
25
26
  sparsity_block_size (int): The size of the sparsity blocks.
26
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
27
27
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
28
28
 
29
29
  Returns:
@@ -39,184 +39,203 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
39
39
  validate_device(src, idx)
40
40
  validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
41
41
  validate_sparsity_block_size(sparsity_block_size, src, idx)
42
- validate_triton_block_size(triton_block_size, sparsity_block_size)
43
42
 
44
43
  adjusted_dim = dim % 3
45
44
 
46
- lut = _BlocksparseGather.build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
47
-
48
- return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
49
- adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
50
- sparsity_block_size, triton_block_size))
51
-
52
-
53
- class _BlocksparseGather(torch.autograd.Function):
54
-
55
- @staticmethod
56
- def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
57
- if lut is None:
58
- lut = dict()
59
-
60
- if "sparsity_reverse_lut_x" not in lut:
61
- sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
62
- sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
63
- (sparsity_layout_x_flat == 1) -
64
- (1 * (sparsity_layout_x_flat == 0)))
65
- lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
66
-
67
- if "sparsity_lut_i" not in lut:
68
- sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
69
- lut["sparsity_lut_i"] = sparsity_lut_i
70
-
71
- validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
72
- sparsity_layout_idx, lut["sparsity_lut_i"])
73
-
74
- return lut
75
-
76
- @staticmethod
77
- def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
78
- dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
79
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
80
- output = torch.empty_like(i, dtype=x.dtype)
81
-
82
- x_b, x_r, x_c = x.size()
83
- x_b_s, x_r_s, x_c_s = stride(x)
84
- s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
85
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
86
- i_b, i_r, i_c = i.size()
87
- i_b_s, i_r_s, i_c_s = stride(i)
88
- s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
89
- s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
90
- o_b, o_r, o_c = output.size()
91
- o_b_s, o_r_s, o_c_s = stride(output)
92
-
93
- if triton_block_size is None:
94
- triton_block_size = get_triton_block_size(sparsity_block_size)
95
-
96
- triton_grid = lambda meta: [o_b,
97
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
98
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
99
-
100
- (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]
101
- (x,
102
- x_b, x_b_s, x_r_s, x_c_s,
103
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
104
- sparsity_reverse_lut_x,
105
- dim,
106
- i,
107
- i_b, i_b_s, i_r_s, i_c_s,
108
- output,
109
- o_b, o_b_s, o_r_s, o_c_s,
110
- sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
111
- sparsity_block_size,
112
- triton_block_size))
113
-
114
- ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
115
- ctx.dim = dim
116
- ctx.sparsity_block_size = sparsity_block_size
117
- ctx.triton_block_size = triton_block_size
118
-
119
- return output
120
-
121
- @staticmethod
122
- def backward(ctx, grad_output):
123
- sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
124
- dim = ctx.dim
125
- sparsity_block_size = ctx.sparsity_block_size
126
- triton_block_size = ctx.triton_block_size
127
-
128
- return scatter_reduce(grad_output, sparsity_layout_i,
129
- dim, i,
130
- sparsity_layout_x, sparsity_block_size,
131
- reduce_op="sum",
132
- triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
133
-
134
- @staticmethod
135
- @triton.jit
136
- def kernel_blocksparse_gather(x,
137
- x_b, x_b_s, x_r_s, x_c_s,
138
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
139
- r_lut_x,
140
- dim,
141
- i,
142
- i_b, i_b_s, i_r_s, i_c_s,
143
- o,
144
- o_b, o_b_s, o_r_s, o_c_s,
145
- s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
146
- sparsity_block_size,
147
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
148
- # Get triton block indices
149
- pid_blk = tl.program_id(axis=0)
150
- pid_row = tl.program_id(axis=1)
151
- pid_col = tl.program_id(axis=2)
152
-
153
- # Get position of current sparsity block consisting of its batch, row, and column index
154
- spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
155
- spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
156
- spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
157
-
158
- spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
159
- spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
160
- spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
161
-
162
- spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
163
- spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
164
- spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
165
-
166
- # Load index values
167
- blk_i_idx = ((pid_blk * i_b_s) +
168
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
169
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
170
- blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
171
- blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
172
-
173
- # Get indices of sparsity blocks and positions within the blocks
174
- pos_spa_blk_x = blk_i // sparsity_block_size
175
- pos_spa_int_x = blk_i % sparsity_block_size
176
-
177
- rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
178
- rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
179
- rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
180
- dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
181
- .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
182
- dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
183
- .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
184
- if dim == 0:
185
- rev_dst_bat_x = blk_i
186
- elif dim == 1:
187
- rev_dst_row_x = pos_spa_blk_x
188
- dst_row_x = pos_spa_int_x * x_r_s
189
- elif dim == 2:
190
- rev_dst_col_x = pos_spa_blk_x
191
- dst_col_x = pos_spa_int_x * x_c_s
192
-
193
- # Load reverse sparsity indices for x
194
- rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
195
- (rev_dst_row_x * s_l_x_r_s) +
196
- (rev_dst_col_x * s_l_x_c_s))
197
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
198
- rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
199
-
200
- # Load x values
201
- blk_x_idx = ((rev_idx_spa_x * x_b_s) +
202
- dst_row_x +
203
- dst_col_x)
204
- blk_x_msk = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s) and rev_idx_spa_x_msk != -1)
205
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
206
-
207
- # Store output
208
- blk_o_idx = ((pid_blk * o_b_s) +
209
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
210
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
211
- blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_x_msk != -1)
212
- tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
45
+ lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
46
+
47
+ return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
48
+ adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
49
+ sparsity_block_size))
50
+
51
+
52
+ @triton_op("blksprs::gather", mutates_args={})
53
+ def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
54
+ dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
55
+ sparsity_block_size: int) -> Tensor:
56
+ output = torch.empty_like(i, dtype=x.dtype)
57
+
58
+ x_b, x_r, x_c = x.size()
59
+ x_b_s, x_r_s, x_c_s = stride(x)
60
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
61
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
62
+ i_b, i_r, i_c = i.size()
63
+ i_b_s, i_r_s, i_c_s = stride(i)
64
+ s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
65
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
66
+ o_b, o_r, o_c = output.size()
67
+ o_b_s, o_r_s, o_c_s = stride(output)
68
+
69
+ triton_grid = lambda meta: [o_b,
70
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
71
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
72
+
73
+ (wrap_triton(gather_kernel)[triton_grid]
74
+ (x,
75
+ x_b, x_b_s, x_r_s, x_c_s,
76
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
77
+ sparsity_reverse_lut_x,
78
+ dim,
79
+ i,
80
+ i_b, i_b_s, i_r_s, i_c_s,
81
+ output,
82
+ o_b, o_b_s, o_r_s, o_c_s,
83
+ sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
84
+ sparsity_block_size))
85
+
86
+ return output
87
+
88
+
89
+ def gather_backward(ctx, grad_output):
90
+ sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
91
+ dim = ctx.dim
92
+ sparsity_block_size = ctx.sparsity_block_size
93
+
94
+ return scatter_reduce(grad_output, sparsity_layout_i,
95
+ dim, i,
96
+ sparsity_layout_x, sparsity_block_size,
97
+ reduce_op="sum"), None, None, None, None, None, None, None
98
+
99
+
100
+ @triton.autotune(
101
+ configs=get_autotune_configs(),
102
+ key=[],
103
+ )
104
+ @triton.jit
105
+ def gather_kernel(x,
106
+ x_b, x_b_s, x_r_s, x_c_s,
107
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
108
+ r_lut_x,
109
+ dim,
110
+ i,
111
+ i_b, i_b_s, i_r_s, i_c_s,
112
+ o,
113
+ o_b, o_b_s, o_r_s, o_c_s,
114
+ s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
115
+ sparsity_block_size,
116
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
117
+ # Get triton block indices
118
+ pid_blk = tl.program_id(axis=0)
119
+ pid_row = tl.program_id(axis=1)
120
+ pid_col = tl.program_id(axis=2)
121
+
122
+ # Get valid triton block size
123
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
124
+
125
+ # Get position of current sparsity block consisting of its batch, row, and column index
126
+ spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
127
+ spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
128
+ spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
129
+
130
+ spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
131
+ spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
132
+ spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
133
+
134
+ spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
135
+ spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
136
+ spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
137
+
138
+ # Load index values
139
+ blk_i_idx = ((pid_blk * i_b_s) +
140
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
141
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
142
+ blk_i_msk = ((blk_i_idx >= 0 and
143
+ blk_i_idx < i_b * i_b_s) and
144
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
145
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
146
+ blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
147
+
148
+ # Get indices of sparsity blocks and positions within the blocks
149
+ pos_spa_blk_x = blk_i // sparsity_block_size
150
+ pos_spa_int_x = blk_i % sparsity_block_size
151
+
152
+ rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
153
+ rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
154
+ rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
155
+ dst_row_x = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
156
+ .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
157
+ dst_col_x = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
158
+ .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
159
+ if dim == 0:
160
+ rev_dst_bat_x = blk_i
161
+ elif dim == 1:
162
+ rev_dst_row_x = pos_spa_blk_x
163
+ dst_row_x = pos_spa_int_x * x_r_s
164
+ elif dim == 2:
165
+ rev_dst_col_x = pos_spa_blk_x
166
+ dst_col_x = pos_spa_int_x * x_c_s
167
+
168
+ # Load reverse sparsity indices for x
169
+ rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
170
+ (rev_dst_row_x * s_l_x_r_s) +
171
+ (rev_dst_col_x * s_l_x_c_s))
172
+ rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0 and
173
+ rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s) and
174
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
175
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
176
+ rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
177
+
178
+ # Load x values
179
+ blk_x_idx = ((rev_idx_spa_x * x_b_s) +
180
+ dst_row_x +
181
+ dst_col_x)
182
+ blk_x_msk = (((blk_x_idx >= 0 and
183
+ blk_x_idx < x_b * x_b_s) and
184
+ rev_idx_spa_x_msk != -1) and
185
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
186
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
187
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
188
+
189
+ # Store output
190
+ blk_o_idx = ((pid_blk * o_b_s) +
191
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
192
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
193
+ blk_o_msk = (((blk_o_idx >= 0 and
194
+ blk_o_idx < o_b * o_b_s) and
195
+ rev_idx_spa_x_msk != -1) and
196
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
197
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
198
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
199
+
200
+
201
+ def gather_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
202
+ if lut is None:
203
+ lut = dict()
204
+
205
+ if "sparsity_reverse_lut_x" not in lut:
206
+ sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
207
+ sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
208
+ (sparsity_layout_x_flat == 1) -
209
+ (1 * (sparsity_layout_x_flat == 0)))
210
+ lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
211
+
212
+ if "sparsity_lut_i" not in lut:
213
+ sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
214
+ lut["sparsity_lut_i"] = sparsity_lut_i
215
+
216
+ validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
217
+ sparsity_layout_idx, lut["sparsity_lut_i"])
218
+
219
+ return lut
220
+
221
+
222
+ # noinspection PyUnusedLocal
223
+ def gather_setup_context(ctx, inputs, output):
224
+ (_, sparsity_layout_x, _, dim, i, sparsity_layout_i, _, sparsity_block_size) = inputs
225
+
226
+ ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
227
+ ctx.dim = dim
228
+ ctx.sparsity_block_size = sparsity_block_size
229
+
230
+
231
+ gather_forward.register_autograd(gather_backward, setup_context=gather_setup_context)
213
232
 
214
233
 
215
234
  def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
216
235
  dim: int,
217
236
  idx: BlksprsTensor,
218
237
  sparsity_layout_tgt: Tensor,
219
- sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
238
+ sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
220
239
  """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
221
240
 
222
241
  """
@@ -225,7 +244,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
225
244
  idx,
226
245
  sparsity_layout_tgt,
227
246
  sparsity_block_size,
228
- reduce_op="none", triton_block_size=triton_block_size)
247
+ reduce_op="none", lut=lut)
229
248
 
230
249
 
231
250
  def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
@@ -233,7 +252,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
233
252
  idx: BlksprsTensor,
234
253
  sparsity_layout_tgt: Tensor,
235
254
  sparsity_block_size: int,
236
- reduce_op: str = "sum", triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
255
+ reduce_op: str = "sum", lut: dict = None) -> BlksprsTensor:
237
256
  """Applies a scatter operation on a block-sparse tensor in compressed form.
238
257
 
239
258
  Args:
@@ -245,7 +264,6 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
245
264
  sparsity_block_size (int): The size of the sparsity blocks.
246
265
  reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
247
266
  Supported operations are ``"none"`` and ``"sum"``.
248
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
249
267
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
250
268
 
251
269
  Returns:
@@ -261,198 +279,218 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
261
279
  validate_device(src, idx)
262
280
  validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
263
281
  validate_sparsity_block_size(sparsity_block_size, src, idx)
264
- validate_triton_block_size(triton_block_size, sparsity_block_size)
265
282
 
266
283
  if reduce_op not in ["none", "sum"]:
267
284
  raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
268
285
 
269
286
  adjusted_dim = dim % 3
270
287
 
271
- lut = _BlocksparseScatterReduce.build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
272
-
273
- return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, lut["sparsity_lut_x"],
274
- adjusted_dim, idx,
275
- sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
276
- sparsity_block_size, lut["n_sparse_blocks"],
277
- reduce_op, triton_block_size))
278
-
279
-
280
- class _BlocksparseScatterReduce(torch.autograd.Function):
281
-
282
- @staticmethod
283
- def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
284
- if lut is None:
285
- lut = dict()
286
-
287
- if "sparsity_lut_x" not in lut:
288
- sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
289
- lut["sparsity_lut_x"] = sparsity_lut_x
290
-
291
- if "sparsity_reverse_lut_o" not in lut:
292
- sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
293
- sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
294
- (sparsity_layout_o_flat == 1) -
295
- (1 * (sparsity_layout_o_flat == 0)))
296
- lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
297
-
298
- if "n_sparse_blocks" not in lut:
299
- n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
300
- lut["n_sparse_blocks"] = n_sparse_blocks
301
-
302
- validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
303
- sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
304
-
305
- return lut
306
-
307
- @staticmethod
308
- def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
309
- dim: int, i: Tensor,
310
- sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
311
- sparsity_block_size: int, n_sparse_blocks: int,
312
- reduce_op: str, triton_block_size: int) -> Tensor:
313
- output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
314
- dtype=x.dtype, device=x.device)
315
-
316
- x_b, x_r, x_c = x.size()
317
- x_b_s, x_r_s, x_c_s = stride(x)
318
- s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
319
- s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
320
- i_b, i_r, i_c = i.size()
321
- i_b_s, i_r_s, i_c_s = stride(i)
322
- o_b, o_r, o_c = output.size()
323
- o_b_s, o_r_s, o_c_s = stride(output)
324
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
325
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
326
-
327
- if triton_block_size is None:
328
- triton_block_size = get_triton_block_size(sparsity_block_size)
329
-
330
- triton_grid = lambda meta: [x_b,
331
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
332
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
333
-
334
- reduce_op_ind = 0
335
- if reduce_op == "sum":
336
- reduce_op_ind = 1
337
-
338
- (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
339
- (x,
340
- x_b, x_b_s, x_r_s, x_c_s,
341
- sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
342
- dim,
343
- i,
344
- i_b, i_b_s, i_r_s, i_c_s,
345
- output,
346
- o_b, o_b_s,
347
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
348
- sparsity_reverse_lut_o,
349
- reduce_op_ind,
350
- sparsity_block_size,
351
- triton_block_size))
352
-
353
- ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
354
- ctx.dim = dim
355
- ctx.sparsity_block_size = sparsity_block_size
356
- ctx.reduce_op = reduce_op
357
- ctx.triton_block_size = triton_block_size
358
-
359
- return output
360
-
361
- @staticmethod
362
- def backward(ctx, grad_output):
363
- sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
364
- dim = ctx.dim
365
- sparsity_block_size = ctx.sparsity_block_size
366
- reduce_op = ctx.reduce_op
367
- triton_block_size = ctx.triton_block_size
368
-
369
- if reduce_op == "sum":
370
- return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
371
- triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
372
- else:
373
- raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
374
-
375
- @staticmethod
376
- @triton.jit
377
- def kernel_blocksparse_scatter(x,
378
- x_b, x_b_s, x_r_s, x_c_s,
379
- s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
380
- dim,
381
- i,
382
- i_b, i_b_s, i_r_s, i_c_s,
383
- o,
384
- o_b, o_b_s,
385
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
386
- r_lut_o,
387
- reduce_op_ind,
388
- sparsity_block_size,
389
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
390
- # Get triton block indices
391
- pid_blk = tl.program_id(axis=0)
392
- pid_row = tl.program_id(axis=1)
393
- pid_col = tl.program_id(axis=2)
394
-
395
- # Get position of current sparsity block consisting of its batch, row, and column index
396
- spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
397
- spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
398
- spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
399
-
400
- spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
401
- spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
402
- spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
403
-
404
- spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
405
- spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
406
- spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
407
-
408
- # Load x values
409
- blk_x_idx = ((pid_blk * x_b_s) +
410
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
411
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
412
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
413
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
414
-
415
- # Load index values
416
- blk_i_idx = ((pid_blk * i_b_s) +
417
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
418
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
419
- blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
420
- blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
421
-
422
- # Get indices of sparsity blocks and positions within the blocks
423
- pos_spa_blk_x = blk_i // sparsity_block_size
424
- pos_spa_int_x = blk_i % sparsity_block_size
425
-
426
- rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
427
- rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
428
- rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
429
- dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
430
- .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
431
- dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
432
- .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
433
- if dim == 0:
434
- rev_dst_bat_o = blk_i
435
- elif dim == 1:
436
- rev_dst_row_o = pos_spa_blk_x
437
- dst_row_o = pos_spa_int_x * x_r_s
438
- elif dim == 2:
439
- rev_dst_col_o = pos_spa_blk_x
440
- dst_col_o = pos_spa_int_x * x_c_s
441
-
442
- # Load reverse sparsity indices for o
443
- rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
444
- (rev_dst_row_o * s_l_o_r_s) +
445
- (rev_dst_col_o * s_l_o_c_s))
446
- rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
447
- rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
448
-
449
- # Store output
450
- blk_o_idx = ((rev_idx_spa_o * o_b_s) +
451
- dst_row_o +
452
- dst_col_o)
453
- blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_o_msk != -1)
454
-
455
- if reduce_op_ind == 0:
456
- tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
457
- elif reduce_op_ind == 1:
458
- tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
288
+ lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
289
+
290
+ return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
291
+ adjusted_dim, idx,
292
+ sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
293
+ sparsity_block_size, lut["n_sparse_blocks"],
294
+ reduce_op))
295
+
296
+
297
+ @triton_op("blksprs::scatter_reduce", mutates_args={})
298
+ def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
299
+ dim: int, i: Tensor,
300
+ sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
301
+ sparsity_block_size: int, n_sparse_blocks: int,
302
+ reduce_op: str) -> Tensor:
303
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
304
+ dtype=x.dtype, device=x.device)
305
+
306
+ x_b, x_r, x_c = x.size()
307
+ x_b_s, x_r_s, x_c_s = stride(x)
308
+ s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
309
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
310
+ i_b, i_r, i_c = i.size()
311
+ i_b_s, i_r_s, i_c_s = stride(i)
312
+ o_b, o_r, o_c = output.size()
313
+ o_b_s, o_r_s, o_c_s = stride(output)
314
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
315
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
316
+
317
+ triton_grid = lambda meta: [x_b,
318
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
319
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
320
+
321
+ reduce_op_ind = 0
322
+ if reduce_op == "sum":
323
+ reduce_op_ind = 1
324
+
325
+ (wrap_triton(scatter_reduce_kernel)[triton_grid]
326
+ (x,
327
+ x_b, x_b_s, x_r_s, x_c_s,
328
+ sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
329
+ dim,
330
+ i,
331
+ i_b, i_b_s, i_r_s, i_c_s,
332
+ output,
333
+ o_b, o_b_s,
334
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
335
+ sparsity_reverse_lut_o,
336
+ reduce_op_ind,
337
+ sparsity_block_size))
338
+
339
+ return output
340
+
341
+
342
+ def scatter_reduce_backward(ctx, grad_output):
343
+ sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
344
+ dim = ctx.dim
345
+ sparsity_block_size = ctx.sparsity_block_size
346
+ reduce_op = ctx.reduce_op
347
+
348
+ if reduce_op == "sum":
349
+ return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x,
350
+ sparsity_block_size), None, None, None, None, None, None, None, None, None
351
+ else:
352
+ raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
353
+
354
+
355
+ @triton.autotune(
356
+ configs=get_autotune_configs(),
357
+ key=[],
358
+ reset_to_zero=["o"]
359
+ )
360
+ @triton.jit
361
+ def scatter_reduce_kernel(x,
362
+ x_b, x_b_s, x_r_s, x_c_s,
363
+ s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
364
+ dim,
365
+ i,
366
+ i_b, i_b_s, i_r_s, i_c_s,
367
+ o,
368
+ o_b, o_b_s,
369
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
370
+ r_lut_o,
371
+ reduce_op_ind,
372
+ sparsity_block_size,
373
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
374
+ # Get triton block indices
375
+ pid_blk = tl.program_id(axis=0)
376
+ pid_row = tl.program_id(axis=1)
377
+ pid_col = tl.program_id(axis=2)
378
+
379
+ # Get valid triton block size
380
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
381
+
382
+ # Get position of current sparsity block consisting of its batch, row, and column index
383
+ spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
384
+ spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
385
+ spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
386
+
387
+ spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
388
+ spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
389
+ spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
390
+
391
+ spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
392
+ spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
393
+ spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
394
+
395
+ # Load x values
396
+ blk_x_idx = ((pid_blk * x_b_s) +
397
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
398
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
399
+ blk_x_msk = ((blk_x_idx >= 0 and
400
+ blk_x_idx < x_b * x_b_s) and
401
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
402
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
403
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
404
+
405
+ # Load index values
406
+ blk_i_idx = ((pid_blk * i_b_s) +
407
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
408
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
409
+ blk_i_msk = ((blk_i_idx >= 0 and
410
+ blk_i_idx < i_b * i_b_s) and
411
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
412
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
413
+ blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
414
+
415
+ # Get indices of sparsity blocks and positions within the blocks
416
+ pos_spa_blk_x = blk_i // sparsity_block_size
417
+ pos_spa_int_x = blk_i % sparsity_block_size
418
+
419
+ rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
420
+ rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
421
+ rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
422
+ dst_row_o = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
423
+ .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
424
+ dst_col_o = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
425
+ .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
426
+ if dim == 0:
427
+ rev_dst_bat_o = blk_i
428
+ elif dim == 1:
429
+ rev_dst_row_o = pos_spa_blk_x
430
+ dst_row_o = pos_spa_int_x * x_r_s
431
+ elif dim == 2:
432
+ rev_dst_col_o = pos_spa_blk_x
433
+ dst_col_o = pos_spa_int_x * x_c_s
434
+
435
+ # Load reverse sparsity indices for o
436
+ rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
437
+ (rev_dst_row_o * s_l_o_r_s) +
438
+ (rev_dst_col_o * s_l_o_c_s))
439
+ rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0 and
440
+ rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s) and
441
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
442
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
443
+ rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
444
+
445
+ # Store output
446
+ blk_o_idx = ((rev_idx_spa_o * o_b_s) +
447
+ dst_row_o +
448
+ dst_col_o)
449
+ blk_o_msk = (((blk_o_idx >= 0 and
450
+ blk_o_idx < o_b * o_b_s) and
451
+ rev_idx_spa_o_msk != -1) and
452
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
453
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
454
+
455
+ if reduce_op_ind == 0:
456
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
457
+ elif reduce_op_ind == 1:
458
+ tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
459
+
460
+
461
+ def scatter_reduce_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
462
+ if lut is None:
463
+ lut = dict()
464
+
465
+ if "sparsity_lut_x" not in lut:
466
+ sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
467
+ lut["sparsity_lut_x"] = sparsity_lut_x
468
+
469
+ if "sparsity_reverse_lut_o" not in lut:
470
+ sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
471
+ sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
472
+ (sparsity_layout_o_flat == 1) -
473
+ (1 * (sparsity_layout_o_flat == 0)))
474
+ lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
475
+
476
+ if "n_sparse_blocks" not in lut:
477
+ n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
478
+ lut["n_sparse_blocks"] = n_sparse_blocks
479
+
480
+ validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
481
+ sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
482
+
483
+ return lut
484
+
485
+
486
+ # noinspection PyUnusedLocal
487
+ def scatter_reduce_setup_context(ctx, inputs, output):
488
+ (_, sparsity_layout_x, _, dim, i, sparsity_layout_o, _, sparsity_block_size, _, reduce_op) = inputs
489
+
490
+ ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
491
+ ctx.dim = dim
492
+ ctx.sparsity_block_size = sparsity_block_size
493
+ ctx.reduce_op = reduce_op
494
+
495
+
496
+ scatter_reduce_forward.register_autograd(scatter_reduce_backward, setup_context=scatter_reduce_setup_context)
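The register_autograd calls closing both blocks replace the old Function.forward/backward pairs: a plain backward function plus a setup_context hook are attached to each @triton_op, and the two backwards delegate to one another (gather's gradient flows through scatter_reduce with reduce_op="sum", scatter_reduce's through gather). Sketching the same wiring for the example::copy op from the sketch near the top of this diff (illustrative only, not blksprs API):

def copy_backward(ctx, grad_output):
    # One gradient slot per forward input; a copy passes the gradient straight through.
    return grad_output

def copy_setup_context(ctx, inputs, output):
    # Nothing to stash for a copy; blksprs saves layouts and index tensors here
    # via ctx.save_for_backward instead of attributes on an autograd.Function.
    pass

copy.register_autograd(copy_backward, setup_context=copy_setup_context)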