blksprs-2.0rc6-py3-none-any.whl → blksprs-2.0rc8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,9 +4,9 @@ from torch import Tensor
  from torch._library.triton import wrap_triton, triton_op
  from triton import language as tl
 
- from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import stride, get_autocast_min_val
  from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
+ from blksprs.utils.tools import stride
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
      validate_sparsity_block_size
 
@@ -60,41 +60,43 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
                                              sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
 
 
- @triton_op("blksprs::row_wise_sum", mutates_args={})
+ @triton_op("blksprs::row_wise_sum_forward", mutates_args={})
  def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
                           sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
                           sparsity_block_size: int, n_sparse_blocks_output: int,
                           flag_slice_only: bool = False) -> Tensor:
-     output = torch.zeros(
-         size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
-         dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-     s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(row_wise_sum_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-       output,
-       o_b, o_b_s, o_r_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-       sparsity_reverse_lut_output,
-       sparsity_block_size))
-
-     return output
-
-
+     with torch.no_grad():
+         output = torch.zeros(
+             size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
+             dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(row_wise_sum_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           output,
+           o_b, o_b_s, o_r_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+           sparsity_reverse_lut_output,
+           sparsity_block_size))
+
+         return output
+
+
+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
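
Note: the pattern in this hunk recurs throughout the release: the op is registered under the same qualified name as the Python function, and allocation plus kernel launch run under torch.no_grad() since gradients are attached separately via register_autograd (see the split/merge hunks below). A minimal, self-contained sketch of that pattern; the mylib::copy_forward op and copy_kernel below are hypothetical stand-ins, not part of blksprs:

    import torch
    import triton
    import triton.language as tl
    from torch import Tensor
    from torch._library.triton import triton_op, wrap_triton

    @triton.jit
    def copy_kernel(x_ptr, o_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        idx = pid * BLOCK + tl.arange(0, BLOCK)
        msk = idx < n
        tl.store(o_ptr + idx, tl.load(x_ptr + idx, mask=msk), mask=msk)

    @triton_op("mylib::copy_forward", mutates_args={})
    def copy_forward(x: Tensor) -> Tensor:
        # No autograd graph is needed inside the op itself.
        with torch.no_grad():
            output = torch.empty_like(x)
            grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
            wrap_triton(copy_kernel)[grid](x, output, x.numel(), BLOCK=1024)
            return output
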
@@ -131,25 +133,22 @@ def row_wise_sum_kernel(x,
      rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
      rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
-     if rev_idx_spa == -1:
-         tl.device_assert(False)
-         return
+     if rev_idx_spa >= 0:
+         blk_idx = ((pid_blk * x_b_s) +
+                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = (blk_idx >= 0 and
+                    blk_idx < x_b * x_b_s)
+         blk = tl.load(x + blk_idx, mask=blk_msk)
 
-     blk_idx = ((pid_blk * x_b_s) +
-                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_msk = (blk_idx >= 0 and
-                blk_idx < x_b * x_b_s)
-     blk = tl.load(x + blk_idx, mask=blk_msk)
+         buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
 
-     buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
-
-     o_idx = (rev_idx_spa * o_b_s +
-              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-              (tl.arange(0, 1))[None, :])
-     o_msk = (o_idx >= 0 and
-              o_idx < o_b * o_b_s)
-     tl.atomic_add(o + o_idx, buf, o_msk)
+         o_idx = (rev_idx_spa * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  (tl.arange(0, 1))[None, :])
+         o_msk = (o_idx >= 0 and
+                  o_idx < o_b * o_b_s)
+         tl.atomic_add(o + o_idx, buf, o_msk)
 
 
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
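
Note: the kernel no longer asserts and bails out on a -1 reverse-LUT entry; all work is instead guarded by rev_idx_spa >= 0, so program instances whose block is absent from the output layout simply do nothing. For intuition, -1 marks an absent block in the reverse LUT. A small self-contained illustration, using the same construction that appears in softmax_backward further down (the layout values here are made up):

    import torch

    # 2x2 sparsity layout for one batch: two of four blocks are present.
    layout = torch.tensor([[[1, 0],
                            [0, 1]]])
    flat = layout.reshape(-1)

    # Present blocks map to their index in the compacted block list,
    # absent blocks map to -1.
    reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1)
                   - (1 * (flat == 0)))
    print(reverse_lut)  # tensor([ 0, -1, -1,  1])
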
@@ -175,6 +174,8 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
          of the input and the sparsity layout of the output tensor.
 
      """
+     # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
+     x = torch.where(x == -0.0, torch.tensor(0.0), x)
      x = x.contiguous()
 
      validate_dimensions(x)
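
Note: this workaround leans on IEEE-754 semantics: -0.0 == 0.0 evaluates to true, so the torch.where maps both signed zeros to +0.0 (a no-op for values already +0.0) before the tensor reaches the Triton kernel affected by the linked bug. A quick dense check of what it does:

    import torch

    x = torch.tensor([-0.0, 0.0, -1.5, 2.0])
    y = torch.where(x == -0.0, torch.tensor(0.0), x)
    print(y)                       # tensor([ 0.0000,  0.0000, -1.5000,  2.0000])
    print(1.0 / x[0], 1.0 / y[0])  # tensor(-inf) tensor(inf): the sign bit is gone
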
@@ -201,43 +202,45 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
                                              n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
 
 
- @triton_op("blksprs::row_wise_max", mutates_args={})
+ @triton_op("blksprs::row_wise_max_forward", mutates_args={})
  def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
                           sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
                           sparsity_block_size: int, n_sparse_blocks_output: int,
                           flag_slice_only: bool = False) -> Tensor:
-     output = torch.full(size=(n_sparse_blocks_output,
-                               sparsity_block_size,
-                               1 if flag_slice_only else sparsity_block_size),
-                         fill_value=get_autocast_min_val(),
-                         device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-     s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(row_wise_max_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-       output,
-       o_b, o_b_s, o_r_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-       sparsity_reverse_lut_output,
-       sparsity_block_size))
-
-     return output
-
-
+     with torch.no_grad():
+         output = torch.full(size=(n_sparse_blocks_output,
+                                   sparsity_block_size,
+                                   1 if flag_slice_only else sparsity_block_size),
+                             fill_value=torch.finfo(x.dtype).min,
+                             device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(row_wise_max_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           output,
+           o_b, o_b_s, o_r_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+           sparsity_reverse_lut_output,
+           sparsity_block_size))
+
+         return output
+
+
+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
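
Note: the removed helper get_autocast_min_val() is replaced by deriving the fill value from the input dtype directly; torch.finfo(x.dtype).min is the most negative finite value of that dtype, i.e. the identity element for the running atomic_max. For reference:

    import torch

    for dtype in (torch.float16, torch.bfloat16, torch.float32):
        print(dtype, torch.finfo(dtype).min)
    # torch.float16 -65504.0
    # torch.bfloat16 -3.3895313892515355e+38
    # torch.float32 -3.4028234663852886e+38
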
@@ -274,25 +277,22 @@ def row_wise_max_kernel(x,
      rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
      rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
-     if rev_idx_spa == -1:
-         tl.device_assert(False)
-         return
-
-     blk_idx = ((pid_blk * x_b_s) +
-                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_msk = (blk_idx >= 0 and
-                blk_idx < x_b * x_b_s)
-     blk = tl.load(x + blk_idx, mask=blk_msk)
+     if rev_idx_spa >= 0:
+         blk_idx = ((pid_blk * x_b_s) +
+                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = (blk_idx >= 0 and
+                    blk_idx < x_b * x_b_s)
+         blk = tl.load(x + blk_idx, mask=blk_msk)
 
-     buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+         buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
 
-     o_idx = (rev_idx_spa * o_b_s +
-              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-              (tl.arange(0, 1))[None, :])
-     o_msk = (o_idx >= 0 and
-              o_idx < o_b * o_b_s)
-     tl.atomic_max(o + o_idx, buf, o_msk)
+         o_idx = (rev_idx_spa * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  (tl.arange(0, 1))[None, :])
+         o_msk = (o_idx >= 0 and
+                  o_idx < o_b * o_b_s)
+         tl.atomic_max(o + o_idx, buf, o_msk)
 
 
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -339,41 +339,43 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
      return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)
 
 
- @triton_op("blksprs::row_wise_add", mutates_args={})
+ @triton_op("blksprs::row_wise_add_forward", mutates_args={})
  def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                           sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
                           y: Tensor, sparsity_block_size: int) -> Tensor:
-     output = torch.zeros_like(x)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut_x.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
-     y_b, y_r, y_c = y.size()
-     y_b_s, y_r_s, y_c_s = stride(y)
-     s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
-     s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
-       y, y_b, y_b_s, y_r_s, y_c_s,
-       s_l_y_b, s_l_y_b_s, s_l_y_r_s,
-       sparsity_reverse_x_lut_rwm,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
-
-     return output
-
-
+     with torch.no_grad():
+         output = torch.zeros_like(x)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut_x.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
+         y_b, y_r, y_c = y.size()
+         y_b_s, y_r_s, y_c_s = stride(y)
+         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
+         s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
+           y, y_b, y_b_s, y_r_s, y_c_s,
+           s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+           sparsity_reverse_x_lut_rwm,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output
+
+
+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
@@ -46,14 +46,15 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
                                       partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
 
- @triton_op("blksprs::split", mutates_args={})
+ @triton_op("blksprs::split_forward", mutates_args={})
  def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                    _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                              n_sparse_blocks)
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)
 
 
- def split_backward(ctx, grad_output):
+ def split_wrapper_backward(ctx, grad_output):
      sparsity_layout = ctx.saved_tensors[0]
      num_partitions = ctx.num_partitions
      dim = ctx.dim
@@ -109,7 +110,7 @@ def split_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size
 
 
- split_forward.register_autograd(split_backward, setup_context=split_setup_context)
+ split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)
 
 
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
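
Note: only the backward wrapper is renamed here (split_backward → split_wrapper_backward, likewise for merge, repeat, softmax, and transpose below); the registration API is unchanged. For readers unfamiliar with it, a minimal sketch of the register_autograd/setup_context pattern used throughout these files, on a hypothetical mylib::scale_forward op (not part of blksprs; shown with the plain torch.library.custom_op decorator for brevity, while the ops in this diff use triton_op, which registers autograd the same way):

    import torch
    from torch import Tensor
    from torch.library import custom_op

    @custom_op("mylib::scale_forward", mutates_args=())
    def scale_forward(x: Tensor, factor: float) -> Tensor:
        with torch.no_grad():
            return x * factor

    def scale_setup_context(ctx, inputs, output):
        # Stash whatever the backward pass needs on ctx.
        _, factor = inputs
        ctx.factor = factor

    def scale_wrapper_backward(ctx, grad_output):
        # One gradient slot per forward input; non-tensor inputs get None.
        return grad_output * ctx.factor, None

    scale_forward.register_autograd(scale_wrapper_backward, setup_context=scale_setup_context)

    x = torch.ones(3, requires_grad=True)
    scale_forward(x, 2.0).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])
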
@@ -150,14 +151,15 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
                                       partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
 
- @triton_op("blksprs::merge", mutates_args={})
+ @triton_op("blksprs::merge_forward", mutates_args={})
  def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                    _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                              n_sparse_blocks)
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)
 
 
- def merge_backward(ctx, grad_output):
+ def merge_wrapper_backward(ctx, grad_output):
      sparsity_layout = ctx.saved_tensors[0]
      num_partitions = ctx.num_partitions
      dim = ctx.dim
@@ -216,4 +218,4 @@ def merge_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size
 
 
- merge_forward.register_autograd(merge_backward, setup_context=merge_setup_context)
+ merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
blksprs/ops/repeat.py CHANGED
@@ -92,15 +92,16 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
                                       lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
 
- @triton_op("blksprs::repeat", mutates_args={})
+ @triton_op("blksprs::repeat_forward", mutates_args={})
  def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
                     sparsity_reverse_lut: Tensor,
                     sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                              n_sparse_blocks)
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)
 
 
- def repeat_backward(ctx, grad_output):
+ def repeat_wrapper_backward(ctx, grad_output):
      sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
      sparsity_block_size = ctx.sparsity_block_size
      n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
@@ -190,4 +191,4 @@ def repeat_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size
 
 
- repeat_forward.register_autograd(repeat_backward, setup_context=repeat_setup_context)
+ repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)
blksprs/ops/softmax.py CHANGED
@@ -47,19 +47,13 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                                          sparsity_block_size))
 
 
- @triton_op("blksprs::softmax", mutates_args={})
+ @triton_op("blksprs::softmax_forward", mutates_args={})
  def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                      sparsity_lut: Tensor,
                      sparsity_reverse_lut_rws: Tensor,
                      sparsity_block_size: int) -> Tensor:
      output = torch.zeros_like(x)
 
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-
      x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
                                                         flag_slice_only=True)
      x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
@@ -67,6 +61,11 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
      x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                             flag_slice_only=True)
 
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     o_b, o_r, o_c = output.size()
      s_b, s_r, s_c = x_exp_row_wise_sum.shape
      s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
      s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
@@ -89,50 +88,58 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
      return output
 
 
- def softmax_backward(ctx, grad_output):
+ def softmax_backward_wrapper(ctx, grad_output):
      o, sparsity_layout, sparsity_lut = ctx.saved_tensors
      sparsity_block_size = ctx.sparsity_block_size
 
-     s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
-
-     sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
-     sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
-                               (sparsity_layout_s_flat == 1) -
-                               (1 * (sparsity_layout_s_flat == 0)))
-
-     o_b, o_r, o_c = o.size()
-     o_b_s, o_r_s, o_c_s = stride(o)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     s_b, s_r, s_c = s.size()
-     s_b_s, s_r_s, s_c_s = stride(s)
-     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
-     s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
-
-     grad_x = torch.zeros_like(o, dtype=torch.float)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     # TODO wrap
-     (softmax_kernel_grad[triton_grid]
-      (grad_output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       o,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       s,
-       s_b, s_b_s, s_r_s, s_c_s,
-       s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-       sparsity_reverse_lut_s,
-       grad_x,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
-
-     return grad_x, None, None, None, None, None
+     return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
+                             sparsity_block_size), None, None, None, None, None
+
+
+ @triton_op("blksprs::softmax_backward", mutates_args={})
+ def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
+                      sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
+
+         sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
+         sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
+                                   (sparsity_layout_s_flat == 1) -
+                                   (1 * (sparsity_layout_s_flat == 0)))
+
+         o_b, o_r, o_c = o.size()
+         o_b_s, o_r_s, o_c_s = stride(o)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         s_b, s_r, s_c = s.size()
+         s_b_s, s_r_s, s_c_s = stride(s)
+         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
+         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
+
+         grad_x = torch.zeros_like(o, dtype=torch.float)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(softmax_kernel_grad)[triton_grid]
+          (grad_output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           o,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           s,
+           s_b, s_b_s, s_r_s, s_c_s,
+           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+           sparsity_reverse_lut_s,
+           grad_x,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return grad_x
 
 
+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
@@ -193,6 +200,7 @@ def softmax_kernel(x,
      tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
 
 
+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
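
Note: the backward pass is now a triton_op of its own, and the old "# TODO wrap" is resolved by routing softmax_kernel_grad through wrap_triton; softmax_backward_wrapper only adapts autograd's (ctx, grad_output) calling convention and pads the remaining input slots with None. The reduction it builds on, s = rowsum(grad_output * o), is the row term of the standard softmax gradient, grad_x = o * (grad_output - s) for o = softmax(x). A dense sanity check of that identity:

    import torch

    x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    g = torch.randn(4, 8, dtype=torch.float64)

    o = torch.softmax(x, dim=-1)
    (grad_auto,) = torch.autograd.grad(o, x, grad_outputs=g)

    # Closed form: o * (g - rowsum(g * o))
    grad_manual = o.detach() * (g - (g * o.detach()).sum(dim=-1, keepdim=True))
    print(torch.allclose(grad_auto, grad_manual))  # True
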
@@ -293,4 +301,4 @@ def softmax_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size
 
 
- softmax_forward.register_autograd(softmax_backward, setup_context=softmax_setup_context)
+ softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
blksprs/ops/transpose.py CHANGED
@@ -28,7 +28,6 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
 
      """
      x = x.contiguous()
-     x_t = x.transpose(-1, -2).contiguous()
 
      validate_dimensions(x)
      validate_contiguous(x)
@@ -38,20 +37,22 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
 
      lut = transpose_build_lut(lut, sparsity_layout)
 
-     return BlksprsTensor(transpose_forward(x_t, lut["sparsity_layout_t"],
+     return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
                                             lut["sparsity_lut"], lut["sparsity_reverse_lut"],
                                             sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
 
 
- @triton_op("blksprs::transpose", mutates_args={})
+ @triton_op("blksprs::transpose_forward", mutates_args={})
  def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
                        sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                        sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                              sparsity_block_size, n_sparse_blocks)
+     with torch.no_grad():
+         x_t = x.transpose(-1, -2).contiguous()
+         return flow_pull_forward(x_t, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                  sparsity_block_size, n_sparse_blocks)
 
 
- def transpose_backward(ctx, grad_output):
+ def transpose_wrapper_backward(ctx, grad_output):
      sparsity_layout = ctx.saved_tensors[0]
      sparsity_block_size = ctx.sparsity_block_size
 
@@ -96,4 +97,4 @@ def transpose_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size
 
 
- transpose_forward.register_autograd(transpose_backward, setup_context=transpose_setup_context)
+ transpose_forward.register_autograd(transpose_wrapper_backward, setup_context=transpose_setup_context)
@@ -2,15 +2,7 @@ import os
 
  blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")
 
- if blksprs_autotune_mode == "TEST":
-     autotune_parameters = [
-         (16, 3, 8),
-
-         (32, 3, 8),
-
-         (64, 3, 8),
-     ]
- elif blksprs_autotune_mode == "DEFAULT":
+ if blksprs_autotune_mode == "DEFAULT":
      autotune_parameters = [
          (16, 3, 8),
          (16, 4, 4),
@@ -28,6 +20,14 @@ elif blksprs_autotune_mode == "DEFAULT":
          (128, 4, 4),
          (128, 5, 2),
      ]
+ elif blksprs_autotune_mode == "TEST":
+     autotune_parameters = [
+         (16, 3, 8),
+
+         (32, 3, 8),
+
+         (64, 3, 8),
+     ]
  else:
      raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")
 
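
Note: the TEST branch now sits after the common DEFAULT branch; the parameter sets themselves are unchanged. Since blksprs_autotune_mode is read from the environment when this module is first imported, BLKSPRS_AUTOTUNE has to be set before the first blksprs import:

    import os

    # Any value other than DEFAULT or TEST raises NotImplementedError on import.
    os.environ["BLKSPRS_AUTOTUNE"] = "TEST"

    import blksprs  # noqa: E402 - deliberately imported after setting the variable
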
@@ -75,4 +75,4 @@ def get_autotune_configs():
          autotune_configs.append(
              triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
 
-     return autotune_configs
+     return autotune_configs
@@ -26,7 +26,6 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
 
      # Apply weights
      sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-     # TODO At the moment, manual cast is needed. Bug with custom_fwd?
      xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
      interim = xw
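
Note: the removed TODO questioned why a manual cast is needed despite custom_fwd; the explicit .to(x.dtype) stays. A plausible explanation, offered here as a hedged sketch rather than a confirmed diagnosis: torch.amp.custom_fwd(cast_inputs=...) only casts the floating-point tensor arguments of the decorated function when autocast is active, and disables autocast inside the decorated body, so tensors materialized inside the function (like the transposed weight here) keep their original dtype:

    import torch

    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
    def op(x: torch.Tensor) -> torch.Tensor:
        # Under torch.autocast("cuda"), the argument x arrives as float16,
        # but a tensor created inside the body is not cast automatically:
        w = torch.ones_like(x, dtype=torch.float32)
        return x * w.to(x.dtype)
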