blksprs-2.0rc7-py3-none-any.whl → blksprs-2.0rc8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -60,39 +60,40 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
          sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output


- @triton_op("blksprs::row_wise_sum", mutates_args={})
+ @triton_op("blksprs::row_wise_sum_forward", mutates_args={})
  def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
                           sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
                           sparsity_block_size: int, n_sparse_blocks_output: int,
                           flag_slice_only: bool = False) -> Tensor:
-     output = torch.zeros(
-         size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
-         dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-     s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(row_wise_sum_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-       output,
-       o_b, o_b_s, o_r_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-       sparsity_reverse_lut_output,
-       sparsity_block_size))
-
-     return output
+     with torch.no_grad():
+         output = torch.zeros(
+             size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
+             dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(row_wise_sum_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           output,
+           o_b, o_b_s, o_r_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+           sparsity_reverse_lut_output,
+           sparsity_block_size))
+
+         return output


  # noinspection PyUnusedLocal
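A pattern repeated throughout this release: every registered forward now runs under torch.no_grad(), since gradients come from the backward registered via register_autograd rather than from tracing the op body. A minimal sketch of that registration pattern, assuming the triton_op/register_autograd API imported as in this diff; the op name "mylib::double" and all helpers below are hypothetical, not part of blksprs:

    import torch
    from torch import Tensor
    from torch.library import triton_op


    @triton_op("mylib::double", mutates_args={})
    def double_forward(x: Tensor) -> Tensor:
        # The forward body is not traced for autograd; grads are supplied below.
        with torch.no_grad():
            return x * 2


    def double_wrapper_backward(ctx, grad_output):
        # d(2x)/dx = 2, so the incoming gradient is scaled by two.
        return grad_output * 2


    def double_setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for this toy backward


    double_forward.register_autograd(double_wrapper_backward, setup_context=double_setup_context)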
@@ -132,25 +133,22 @@ def row_wise_sum_kernel(x,
      rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
      rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-     if rev_idx_spa == -1:
-         tl.device_assert(False)
-         return
-
-     blk_idx = ((pid_blk * x_b_s) +
-                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_msk = (blk_idx >= 0 and
-                blk_idx < x_b * x_b_s)
-     blk = tl.load(x + blk_idx, mask=blk_msk)
+     if rev_idx_spa >= 0:
+         blk_idx = ((pid_blk * x_b_s) +
+                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = (blk_idx >= 0 and
+                    blk_idx < x_b * x_b_s)
+         blk = tl.load(x + blk_idx, mask=blk_msk)

-     buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+         buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

-     o_idx = (rev_idx_spa * o_b_s +
-              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-              (tl.arange(0, 1))[None, :])
-     o_msk = (o_idx >= 0 and
-              o_idx < o_b * o_b_s)
-     tl.atomic_add(o + o_idx, buf, o_msk)
+         o_idx = (rev_idx_spa * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  (tl.arange(0, 1))[None, :])
+         o_msk = (o_idx >= 0 and
+                  o_idx < o_b * o_b_s)
+         tl.atomic_add(o + o_idx, buf, o_msk)


  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -176,7 +174,7 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
      of the input and the sparsity layout of the output tensor.

      """
-     # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376
+     # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
      x = torch.where(x == -0.0, torch.tensor(0.0), x)
      x = x.contiguous()

@@ -204,41 +202,42 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
          n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output


- @triton_op("blksprs::row_wise_max", mutates_args={})
+ @triton_op("blksprs::row_wise_max_forward", mutates_args={})
  def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
                           sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
                           sparsity_block_size: int, n_sparse_blocks_output: int,
                           flag_slice_only: bool = False) -> Tensor:
-     output = torch.full(size=(n_sparse_blocks_output,
-                               sparsity_block_size,
-                               1 if flag_slice_only else sparsity_block_size),
-                         fill_value=torch.finfo(x.dtype).min,
-                         device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-     s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(row_wise_max_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-       output,
-       o_b, o_b_s, o_r_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-       sparsity_reverse_lut_output,
-       sparsity_block_size))
-
-     return output
+     with torch.no_grad():
+         output = torch.full(size=(n_sparse_blocks_output,
+                                   sparsity_block_size,
+                                   1 if flag_slice_only else sparsity_block_size),
+                             fill_value=torch.finfo(x.dtype).min,
+                             device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(row_wise_max_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           output,
+           o_b, o_b_s, o_r_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+           sparsity_reverse_lut_output,
+           sparsity_block_size))
+
+         return output


  # noinspection PyUnusedLocal
@@ -278,25 +277,22 @@ def row_wise_max_kernel(x,
      rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
      rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-     if rev_idx_spa == -1:
-         tl.device_assert(False)
-         return
-
-     blk_idx = ((pid_blk * x_b_s) +
-                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_msk = (blk_idx >= 0 and
-                blk_idx < x_b * x_b_s)
-     blk = tl.load(x + blk_idx, mask=blk_msk)
+     if rev_idx_spa >= 0:
+         blk_idx = ((pid_blk * x_b_s) +
+                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = (blk_idx >= 0 and
+                    blk_idx < x_b * x_b_s)
+         blk = tl.load(x + blk_idx, mask=blk_msk)

-     buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+         buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

-     o_idx = (rev_idx_spa * o_b_s +
-              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-              (tl.arange(0, 1))[None, :])
-     o_msk = (o_idx >= 0 and
-              o_idx < o_b * o_b_s)
-     tl.atomic_max(o + o_idx, buf, o_msk)
+         o_idx = (rev_idx_spa * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  (tl.arange(0, 1))[None, :])
+         o_msk = (o_idx >= 0 and
+                  o_idx < o_b * o_b_s)
+         tl.atomic_max(o + o_idx, buf, o_msk)


  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
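Both row-wise kernels now guard the whole block computation with `if rev_idx_spa >= 0:` instead of asserting and returning early when the reverse LUT holds -1: a program assigned to an absent output block simply does nothing rather than triggering a device-side abort. A sketch of the same guard in a self-contained kernel; all names here are illustrative, not the blksprs kernels:

    import triton
    import triton.language as tl


    @triton.jit
    def guarded_copy_kernel(src, dst, lut, n_elems, BLOCK: tl.constexpr):
        # Hypothetical kernel: a negative LUT entry marks an absent block,
        # so the program idles instead of asserting and returning early.
        pid = tl.program_id(0)
        target = tl.load(lut + pid).to(tl.int32)
        if target >= 0:
            offs = pid * BLOCK + tl.arange(0, BLOCK)
            msk = offs < n_elems
            vals = tl.load(src + offs, mask=msk)
            tl.store(dst + target * BLOCK + tl.arange(0, BLOCK), vals, mask=msk)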
@@ -343,41 +339,43 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
      return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)


- @triton_op("blksprs::row_wise_add", mutates_args={})
+ @triton_op("blksprs::row_wise_add_forward", mutates_args={})
  def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                           sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
                           y: Tensor, sparsity_block_size: int) -> Tensor:
-     output = torch.zeros_like(x)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut_x.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
-     y_b, y_r, y_c = y.size()
-     y_b_s, y_r_s, y_c_s = stride(y)
-     s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
-     s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
-       y, y_b, y_b_s, y_r_s, y_c_s,
-       s_l_y_b, s_l_y_b_s, s_l_y_r_s,
-       sparsity_reverse_x_lut_rwm,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
-
-     return output
+     with torch.no_grad():
+         output = torch.zeros_like(x)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut_x.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
+         y_b, y_r, y_c = y.size()
+         y_b_s, y_r_s, y_c_s = stride(y)
+         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
+         s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
+           y, y_b, y_b_s, y_r_s, y_c_s,
+           s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+           sparsity_reverse_x_lut_rwm,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output


+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
@@ -46,14 +46,15 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
          partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]


- @triton_op("blksprs::split", mutates_args={})
+ @triton_op("blksprs::split_forward", mutates_args={})
  def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                    _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                              n_sparse_blocks)
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)


- def split_backward(ctx, grad_output):
+ def split_wrapper_backward(ctx, grad_output):
      sparsity_layout = ctx.saved_tensors[0]
      num_partitions = ctx.num_partitions
      dim = ctx.dim
@@ -109,7 +110,7 @@ def split_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size


- split_forward.register_autograd(split_backward, setup_context=split_setup_context)
+ split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)


  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -150,14 +151,15 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
          partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]


- @triton_op("blksprs::merge", mutates_args={})
+ @triton_op("blksprs::merge_forward", mutates_args={})
  def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                    _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                              n_sparse_blocks)
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)


- def merge_backward(ctx, grad_output):
+ def merge_wrapper_backward(ctx, grad_output):
      sparsity_layout = ctx.saved_tensors[0]
      num_partitions = ctx.num_partitions
      dim = ctx.dim
@@ -216,4 +218,4 @@ def merge_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size


- merge_forward.register_autograd(merge_backward, setup_context=merge_setup_context)
+ merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
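split_forward and merge_forward, like repeat_forward and transpose_forward further down, all delegate to the shared flow_pull_forward, in which each output block pulls its contents from the input block named by a reverse look-up table. A dense sketch of that pull semantics under assumed conventions (-1 marking an absent source block, matching the kernel guards above); an illustration, not the library kernel:

    import torch


    def flow_pull_dense(x_blocks: torch.Tensor, reverse_lut: torch.Tensor,
                        sparsity_block_size: int) -> torch.Tensor:
        # One output block per reverse-LUT entry; entry i names its source block.
        out = torch.zeros(reverse_lut.numel(), sparsity_block_size, sparsity_block_size,
                          dtype=x_blocks.dtype, device=x_blocks.device)
        for o_idx, i_idx in enumerate(reverse_lut.tolist()):
            if i_idx >= 0:  # -1 marks an absent block and is skipped
                out[o_idx] = x_blocks[i_idx]
        return out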
blksprs/ops/repeat.py CHANGED
@@ -92,15 +92,16 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
          lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]


- @triton_op("blksprs::repeat", mutates_args={})
+ @triton_op("blksprs::repeat_forward", mutates_args={})
  def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
                     sparsity_reverse_lut: Tensor,
                     sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                              n_sparse_blocks)
+     with torch.no_grad():
+         return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                  n_sparse_blocks)


- def repeat_backward(ctx, grad_output):
+ def repeat_wrapper_backward(ctx, grad_output):
      sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
      sparsity_block_size = ctx.sparsity_block_size
      n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
@@ -190,4 +191,4 @@ def repeat_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size


- repeat_forward.register_autograd(repeat_backward, setup_context=repeat_setup_context)
+ repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)
blksprs/ops/softmax.py CHANGED
@@ -47,19 +47,13 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
          sparsity_block_size))


- @triton_op("blksprs::softmax", mutates_args={})
+ @triton_op("blksprs::softmax_forward", mutates_args={})
  def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                      sparsity_lut: Tensor,
                      sparsity_reverse_lut_rws: Tensor,
                      sparsity_block_size: int) -> Tensor:
      output = torch.zeros_like(x)

-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-
      x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
                                                         flag_slice_only=True)
      x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
@@ -67,6 +61,11 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
      x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                             flag_slice_only=True)

+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     o_b, o_r, o_c = output.size()
      s_b, s_r, s_c = x_exp_row_wise_sum.shape
      s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
      s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
@@ -89,50 +88,58 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
      return output


- def softmax_backward(ctx, grad_output):
+ def softmax_backward_wrapper(ctx, grad_output):
      o, sparsity_layout, sparsity_lut = ctx.saved_tensors
      sparsity_block_size = ctx.sparsity_block_size

-     s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
-
-     sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
-     sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
-                               (sparsity_layout_s_flat == 1) -
-                               (1 * (sparsity_layout_s_flat == 0)))
-
-     o_b, o_r, o_c = o.size()
-     o_b_s, o_r_s, o_c_s = stride(o)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     s_b, s_r, s_c = s.size()
-     s_b_s, s_r_s, s_c_s = stride(s)
-     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
-     s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
-
-     grad_x = torch.zeros_like(o, dtype=torch.float)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     # TODO wrap
-     (softmax_kernel_grad[triton_grid]
-      (grad_output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       o,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       s,
-       s_b, s_b_s, s_r_s, s_c_s,
-       s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-       sparsity_reverse_lut_s,
-       grad_x,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
-
-     return grad_x, None, None, None, None, None
+     return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
+                             sparsity_block_size), None, None, None, None, None
+
+
+ @triton_op("blksprs::softmax_backward", mutates_args={})
+ def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
+                      sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
+
+         sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
+         sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
+                                   (sparsity_layout_s_flat == 1) -
+                                   (1 * (sparsity_layout_s_flat == 0)))
+
+         o_b, o_r, o_c = o.size()
+         o_b_s, o_r_s, o_c_s = stride(o)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         s_b, s_r, s_c = s.size()
+         s_b_s, s_r_s, s_c_s = stride(s)
+         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
+         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
+
+         grad_x = torch.zeros_like(o, dtype=torch.float)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(softmax_kernel_grad)[triton_grid]
+          (grad_output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           o,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           s,
+           s_b, s_b_s, s_r_s, s_c_s,
+           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+           sparsity_reverse_lut_s,
+           grad_x,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return grad_x


+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
@@ -193,6 +200,7 @@ def softmax_kernel(x,
      tl.store(o + blk_x_idx, buf, mask=blk_x_msk)


+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
@@ -293,4 +301,4 @@ def softmax_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size


- softmax_forward.register_autograd(softmax_backward, setup_context=softmax_setup_context)
+ softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
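The softmax backward is now a registered op of its own, blksprs::softmax_backward, with its kernel launched through wrap_triton, which resolves the old `# TODO wrap`. The quantity it computes is the standard softmax VJP: with o = softmax(x) and incoming gradient g, grad_x = o * (g - sum(g * o, dim=-1)), where the row sums come from the blocksparse row_wise_sum above. A dense reference for comparison (a sketch, not the library code):

    import torch


    def softmax_backward_dense(g: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        # VJP of o = softmax(x): grad_x = o * (g - per-row inner product of g and o).
        return o * (g - (g * o).sum(dim=-1, keepdim=True))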
blksprs/ops/transpose.py CHANGED
@@ -28,7 +28,6 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,

      """
      x = x.contiguous()
-     x_t = x.transpose(-1, -2).contiguous()

      validate_dimensions(x)
      validate_contiguous(x)
@@ -38,20 +37,22 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,

      lut = transpose_build_lut(lut, sparsity_layout)

-     return BlksprsTensor(transpose_forward(x_t, lut["sparsity_layout_t"],
+     return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
                                             lut["sparsity_lut"], lut["sparsity_reverse_lut"],
                                             sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]


- @triton_op("blksprs::transpose", mutates_args={})
+ @triton_op("blksprs::transpose_forward", mutates_args={})
  def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
                        sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                        sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                              sparsity_block_size, n_sparse_blocks)
+     with torch.no_grad():
+         x_t = x.transpose(-1, -2).contiguous()
+         return flow_pull_forward(x_t, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                  sparsity_block_size, n_sparse_blocks)


- def transpose_backward(ctx, grad_output):
+ def transpose_wrapper_backward(ctx, grad_output):
      sparsity_layout = ctx.saved_tensors[0]
      sparsity_block_size = ctx.sparsity_block_size

@@ -96,4 +97,4 @@ def transpose_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size


- transpose_forward.register_autograd(transpose_backward, setup_context=transpose_setup_context)
+ transpose_forward.register_autograd(transpose_wrapper_backward, setup_context=transpose_setup_context)
@@ -2,15 +2,7 @@ import os

  blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")

- if blksprs_autotune_mode == "TEST":
-     autotune_parameters = [
-         (16, 3, 8),
-
-         (32, 3, 8),
-
-         (64, 3, 8),
-     ]
- elif blksprs_autotune_mode == "DEFAULT":
+ if blksprs_autotune_mode == "DEFAULT":
      autotune_parameters = [
          (16, 3, 8),
          (16, 4, 4),
@@ -28,6 +20,14 @@ elif blksprs_autotune_mode == "DEFAULT":
          (128, 4, 4),
          (128, 5, 2),
      ]
+ elif blksprs_autotune_mode == "TEST":
+     autotune_parameters = [
+         (16, 3, 8),
+
+         (32, 3, 8),
+
+         (64, 3, 8),
+     ]
  else:
      raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")

@@ -75,4 +75,4 @@ def get_autotune_configs():
          autotune_configs.append(
              triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))

-     return autotune_configs
+     return autotune_configs
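The branches were reordered so the full DEFAULT search space is listed first and the reduced TEST space second; either way, each (block_size, num_stages, num_warps) tuple is expanded into a triton.Config keyed on TRITON_BLOCK_SIZE, as get_autotune_configs shows above. A sketch of that expansion using the TEST tuples from this diff:

    import triton

    # The TEST parameter set, expanded the same way get_autotune_configs does.
    autotune_parameters = [(16, 3, 8), (32, 3, 8), (64, 3, 8)]
    autotune_configs = [
        triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps)
        for block_size, num_stages, num_warps in autotune_parameters
    ]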
@@ -26,7 +26,6 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block

      # Apply weights
      sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-     # TODO At the moment, manual cast is needed. Bug with custom_fwd?
      xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
      interim = xw

blksprs/utils/tools.py CHANGED
@@ -1,3 +1,6 @@
+ import tomllib
+ from pathlib import Path
+
  import torch
  from torch import Tensor, Size

@@ -5,6 +8,11 @@ from torch import Tensor, Size
  torch._dynamo.config.capture_scalar_outputs = True


+ def version():
+     with open(Path(__file__).parent.parent.parent.joinpath("pyproject.toml"), "rb") as f:
+         return tomllib.load(f)["project"]["version"]
+
+
  def do_shape_blocksparse(x: Tensor):
      if x.dim() == 3:
          return x.contiguous(), x.size()
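Note that the new version() helper resolves pyproject.toml relative to the source tree, which exists in a checkout but not inside an installed wheel; for an installed package, importlib.metadata is the usual fallback (a sketch, not part of the package):

    from importlib.metadata import version as pkg_version

    print(pkg_version("blksprs"))  # e.g. "2.0rc8" for the wheel this diff describes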