blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,19 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library.triton import wrap_triton, triton_op
  from triton import language as tl

+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
-     validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity_block_size


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
  def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                  flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+                  flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
      """Computes the row-wise sum of a block-sparse tensor.

      Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
@@ -25,7 +28,6 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
          sparsity_block_size (int): The size of the sparsity blocks.
          flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
              (default ``False``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

      Returns:
          tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
@@ -39,7 +41,6 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

      sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

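The hunks above show the headline API change of 2.0: the manual `triton_block_size` parameter (and its validation) is removed, and the kernel block size is instead chosen by an autotuner. A minimal usage sketch of the new call signature; the top-level import path and the layout/tensor construction are assumptions for illustration, not taken from this diff:

```python
import torch
import blksprs as bs  # assumed top-level re-export of row_wise_sum / BlksprsTensor

b, r, c, sparsity_block_size = 2, 128, 128, 64
# Fully dense layout: every block present, so the compressed form is simply
# all blocks stacked along the leading dimension.
sparsity_layout = torch.ones(b, r // sparsity_block_size, c // sparsity_block_size,
                             dtype=torch.int64, device="cuda")
x = bs.BlksprsTensor(torch.randn(int(sparsity_layout.sum()), sparsity_block_size,
                                 sparsity_block_size, device="cuda"))

# 1.10.2: bs.row_wise_sum(x, sparsity_layout, sparsity_block_size, triton_block_size=64)
# 2.0: no triton_block_size argument; the autotuner picks the kernel block size.
sums, sparsity_layout_sums = bs.row_wise_sum(x, sparsity_layout, sparsity_block_size,
                                             flag_slice_only=True)
```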
@@ -54,50 +55,65 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
      validate_contiguous(sparsity_layout, sparsity_lut,
                          sparsity_layout_output, sparsity_reverse_lut_output)

-     output = torch.zeros(size=(n_sparse_blocks_output,
-                                sparsity_block_size,
-                                1 if flag_slice_only else sparsity_block_size),
-                          dtype=x.dtype,
-                          device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-     s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
-
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (kernel_blocksparse_row_wise_sum[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-       output,
-       o_b, o_b_s, o_r_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-       sparsity_reverse_lut_output,
-       triton_block_size))
-
-     return BlksprsTensor(output), sparsity_layout_output
-
-
+     return BlksprsTensor(row_wise_sum_forward(
+         x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output,
+         sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
+
+
+ @triton_op("blksprs::row_wise_sum_forward", mutates_args={})
+ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
+                          sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
+                          sparsity_block_size: int, n_sparse_blocks_output: int,
+                          flag_slice_only: bool = False) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(
+             size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
+             dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(row_wise_sum_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           output,
+           o_b, o_b_s, o_r_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+           sparsity_reverse_lut_output,
+           sparsity_block_size))
+
+         return output
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_blocksparse_row_wise_sum(x,
-                                     x_b, x_b_s, x_r_s, x_c_s,
-                                     s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                     o,
-                                     o_b, o_b_s, o_r_s,
-                                     s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-                                     r_lut_o,
-                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def row_wise_sum_kernel(x,
+                         x_b, x_b_s, x_r_s, x_c_s,
+                         s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                         o,
+                         o_b, o_b_s, o_r_s,
+                         s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                         r_lut_o,
+                         sparsity_block_size,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
      pid_blk = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
      pid_col = tl.program_id(axis=2)
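The refactor in this hunk is the core of the 2.0 release: the eager launch code moves into a `triton_op` custom operator, and the raw kernel launch goes through `wrap_triton` so that `torch.compile` can trace it instead of treating it as an opaque side effect. A self-contained toy version of the same pattern; the `demo::add_one` op and kernel names are illustrative, not part of blksprs:

```python
import torch
import triton
import triton.language as tl
from torch._library.triton import triton_op, wrap_triton


@triton.jit
def add_one_kernel(x_ptr, o_ptr, n, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized slice of the flat tensor.
    offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    msk = offs < n
    tl.store(o_ptr + offs, tl.load(x_ptr + offs, mask=msk) + 1, mask=msk)


@triton_op("demo::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
    # wrap_triton registers the launch with the custom op, making it
    # visible to torch.compile rather than an untraceable kernel call.
    wrap_triton(add_one_kernel)[grid](x, output, x.numel(), BLOCK=1024)
    return output
```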
@@ -117,27 +133,27 @@ def kernel_blocksparse_row_wise_sum(x,
      rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
      rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-     if rev_idx_spa == -1:
-         tl.device_assert(False)
-         return
-
-     blk_idx = ((pid_blk * x_b_s) +
-                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
-     blk = tl.load(x + blk_idx, mask=blk_msk)
+     if rev_idx_spa >= 0:
+         blk_idx = ((pid_blk * x_b_s) +
+                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = (blk_idx >= 0 and
+                    blk_idx < x_b * x_b_s)
+         blk = tl.load(x + blk_idx, mask=blk_msk)

-     buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+         buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

-     o_idx = (rev_idx_spa * o_b_s +
-              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-              (tl.arange(0, 1))[None, :])
-     o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
-     tl.atomic_add(o + o_idx, buf, o_msk)
+         o_idx = (rev_idx_spa * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  (tl.arange(0, 1))[None, :])
+         o_msk = (o_idx >= 0 and
+                  o_idx < o_b * o_b_s)
+         tl.atomic_add(o + o_idx, buf, o_msk)


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                  flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+                  flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
      """Computes the row-wise max of a block-sparse tensor.

      Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
@@ -152,13 +168,14 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
          sparsity_block_size (int): The size of the sparsity blocks.
          flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
              (default ``False``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

      Returns:
          tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
          of the input and the sparsity layout of the output tensor.

      """
+     # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
+     x = torch.where(x == -0.0, torch.tensor(0.0), x)
      x = x.contiguous()

      validate_dimensions(x)
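The two lines added near the end of this hunk work around the Triton issue referenced in the TODO, which by context involves signed zeros in the max reduction. Under IEEE 754, `-0.0 == 0.0` compares true, so the `torch.where` rewrites every negative zero (and, harmlessly, every positive zero) to `+0.0` before the kernel runs, leaving all other values untouched. The mechanism in isolation:

```python
import torch

x = torch.tensor([-0.0, 0.0, -1.0, 2.0])
x = torch.where(x == -0.0, torch.tensor(0.0), x)  # -0.0 == 0.0 is True, so both zeros match
print(x.signbit())  # tensor([False, False,  True, False]); only -1.0 keeps a set sign bit
```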
@@ -166,7 +183,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

      sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

@@ -181,50 +197,67 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
      validate_contiguous(sparsity_layout, sparsity_lut,
                          sparsity_layout_output, sparsity_reverse_lut_output)

-     output = torch.full(size=(n_sparse_blocks_output,
-                               sparsity_block_size,
-                               1 if flag_slice_only else sparsity_block_size),
-                         fill_value=float("-inf"),
-                         device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-     s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
-
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (kernel_blocksparse_row_wise_max[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-       output,
-       o_b, o_b_s, o_r_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-       sparsity_reverse_lut_output,
-       triton_block_size))
-
-     return BlksprsTensor(output), sparsity_layout_output
-
-
+     return BlksprsTensor(
+         row_wise_max_forward(x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output, sparsity_block_size,
+                              n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
+
+
+ @triton_op("blksprs::row_wise_max_forward", mutates_args={})
+ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
+                          sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
+                          sparsity_block_size: int, n_sparse_blocks_output: int,
+                          flag_slice_only: bool = False) -> Tensor:
+     with torch.no_grad():
+         output = torch.full(size=(n_sparse_blocks_output,
+                                   sparsity_block_size,
+                                   1 if flag_slice_only else sparsity_block_size),
+                             fill_value=torch.finfo(x.dtype).min,
+                             device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(row_wise_max_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           output,
+           o_b, o_b_s, o_r_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+           sparsity_reverse_lut_output,
+           sparsity_block_size))
+
+         return output
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     restore_value=["o"]
+ )
  @triton.jit
- def kernel_blocksparse_row_wise_max(x,
-                                     x_b, x_b_s, x_r_s, x_c_s,
-                                     s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                     o,
-                                     o_b, o_b_s, o_r_s,
-                                     s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-                                     r_lut_o,
-                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def row_wise_max_kernel(x,
+                         x_b, x_b_s, x_r_s, x_c_s,
+                         s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                         o,
+                         o_b, o_b_s, o_r_s,
+                         s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                         r_lut_o,
+                         sparsity_block_size,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
      pid_blk = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
      pid_col = tl.program_id(axis=2)
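Note the asymmetry between the two autotune decorators introduced in this release: the sum kernel uses `reset_to_zero=["o"]`, while the max kernel above uses `restore_value=["o"]`. Autotuning benchmarks each config by re-running the kernel, and both kernels accumulate into `o` with atomics, so `o` must be returned to its pre-run state between trials; zero-filling is correct for the additive sum, but the max buffer is initialized to `torch.finfo(x.dtype).min` and therefore needs snapshot-and-restore. A generic sketch of the mechanism with assumed toy configs, not blksprs' actual `get_autotune_configs`:

```python
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({"TRITON_BLOCK_SIZE": bs}) for bs in (16, 32, 64)],
    key=["n"],
    restore_value=["o"],  # snapshot `o` before each timing run, restore it after
)
@triton.jit
def toy_atomic_max_kernel(x, o, n, TRITON_BLOCK_SIZE: tl.constexpr):
    # Flat 1-D reduction: each program atomically maxes its slice into `o`.
    offs = tl.program_id(axis=0) * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    msk = offs < n
    tl.atomic_max(o + offs, tl.load(x + offs, mask=msk), msk)
```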
@@ -244,27 +277,27 @@ def kernel_blocksparse_row_wise_max(x,
      rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
      rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-     if rev_idx_spa == -1:
-         tl.device_assert(False)
-         return
-
-     blk_idx = ((pid_blk * x_b_s) +
-                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
-     blk = tl.load(x + blk_idx, mask=blk_msk)
+     if rev_idx_spa >= 0:
+         blk_idx = ((pid_blk * x_b_s) +
+                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = (blk_idx >= 0 and
+                    blk_idx < x_b * x_b_s)
+         blk = tl.load(x + blk_idx, mask=blk_msk)

-     buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+         buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

-     o_idx = (rev_idx_spa * o_b_s +
-              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-              (tl.arange(0, 1))[None, :])
-     o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
-     tl.atomic_max(o + o_idx, buf, o_msk)
+         o_idx = (rev_idx_spa * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  (tl.arange(0, 1))[None, :])
+         o_msk = (o_idx >= 0 and
+                  o_idx < o_b * o_b_s)
+         tl.atomic_max(o + o_idx, buf, o_msk)


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
-                  sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+                  sparsity_block_size: int) -> BlksprsTensor:
      """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.

      Args:
@@ -272,7 +305,6 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
          sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
          y (BlksprsTensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

      Returns:
          BlksprsTensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
@@ -284,9 +316,8 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

-     sparsity_lut = torch.nonzero(sparsity_layout_x).contiguous()
+     sparsity_lut_x = torch.nonzero(sparsity_layout_x).contiguous()

      sparsity_layout_rwm, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)
      sparsity_layout_rwm_flat = sparsity_layout_rwm.reshape(-1)
@@ -294,60 +325,73 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
                                  (sparsity_layout_rwm_flat == 1) -
                                  (1 * (sparsity_layout_rwm_flat == 0)))

-     validate_contiguous(sparsity_layout_x, sparsity_lut, sparsity_reverse_lut_rwm)
-
-     output = torch.empty_like(x)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     y_b, y_r, y_c = y.size()
-     y_b_s, y_r_s, y_c_s = stride(y)
-     s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
-     s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (kernel_blocksparse_row_wise_add[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       y, y_b, y_b_s, y_r_s, y_c_s,
-       s_l_y_b, s_l_y_b_s, s_l_y_r_s,
-       sparsity_reverse_lut_rwm,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       triton_block_size
-       ))
+     validate_contiguous(sparsity_layout_x, sparsity_lut_x, sparsity_reverse_lut_rwm)

-     return BlksprsTensor(output)
+     return BlksprsTensor(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
+                                               sparsity_reverse_lut_rwm, y, sparsity_block_size))


  def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
-                  sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+                  sparsity_block_size: int) -> BlksprsTensor:
      """Wrapper for ``row_wise_add`` with negated y.

      """
-     return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size, triton_block_size)
-
-
+     return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)
+
+
+ @triton_op("blksprs::row_wise_add_forward", mutates_args={})
+ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
+                          sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
+                          y: Tensor, sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros_like(x)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut_x.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
+         y_b, y_r, y_c = y.size()
+         y_b_s, y_r_s, y_c_s = stride(y)
+         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
+         s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
+           y, y_b, y_b_s, y_r_s, y_c_s,
+           s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+           sparsity_reverse_x_lut_rwm,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
  @triton.jit
  def kernel_blocksparse_row_wise_add(x,
                                      x_b, x_b_s, x_r_s, x_c_s,
-                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                     s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                                      y, y_b, y_b_s, y_r_s, y_c_s,
                                      s_l_y_b, s_l_y_b_s, s_l_y_r_s,
                                      r_lut_y,
                                      o,
                                      o_b, o_b_s, o_r_s, o_c_s,
+                                     sparsity_block_size,
                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
      pid_blk = tl.program_id(axis=0)
@@ -355,13 +399,13 @@ def kernel_blocksparse_row_wise_add(x,
      pid_col = tl.program_id(axis=2)

      # Get position of current sparsity block consisting of its batch and row index
-     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)

-     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+     spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)

      # Get reverse sparsity indices for s
      rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
@@ -377,14 +421,16 @@ def kernel_blocksparse_row_wise_add(x,
      blk_x_idx = ((pid_blk * x_b_s) +
                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+     blk_x_msk = (blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s)
      blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

      # Load sum block
      blk_s_idx = (rev_idx_spa_s * y_b_s +
                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                   (tl.arange(0, 1) * y_c_s)[None, :])
-     blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < y_b * y_b_s)
+     blk_s_msk = (blk_s_idx >= 0 and
+                  blk_s_idx < y_b * y_b_s)
      blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)

      # Compute exp
@@ -394,5 +440,6 @@ def kernel_blocksparse_row_wise_add(x,
      blk_o_idx = ((pid_blk * o_b_s) +
                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-     blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+     blk_o_msk = (blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s)
      tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
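Taken together, `row_wise_max` and `row_wise_sub` provide the usual building block for numerically stable row-wise computations, such as shifting each row by its max before an exponential. A closing usage sketch under the same assumptions as above; the top-level import path and layout construction are illustrative, not taken from this diff:

```python
import torch
import blksprs as bs  # assumed top-level re-export

b, r, c, sparsity_block_size = 2, 128, 128, 64
sparsity_layout = torch.ones(b, r // sparsity_block_size, c // sparsity_block_size,
                             dtype=torch.int64, device="cuda")
x = bs.BlksprsTensor(torch.randn(int(sparsity_layout.sum()), sparsity_block_size,
                                 sparsity_block_size, device="cuda"))

# One value per row, single column of sparse blocks (flag_slice_only=True).
row_max, sparsity_layout_rwm = bs.row_wise_max(x, sparsity_layout, sparsity_block_size,
                                               flag_slice_only=True)
# row_wise_sub is row_wise_add with a negated y, per the wrapper in this diff.
shifted = bs.row_wise_sub(x, sparsity_layout, row_max, sparsity_block_size)
```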