blksprs 2.0rc7__py3-none-any.whl → 2.1__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -60,39 +60,40 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
60
60
  sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
61
61
 
62
62
 
63
- @triton_op("blksprs::row_wise_sum", mutates_args={})
63
+ @triton_op("blksprs::row_wise_sum_forward", mutates_args={})
64
64
  def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
65
65
  sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
66
66
  sparsity_block_size: int, n_sparse_blocks_output: int,
67
67
  flag_slice_only: bool = False) -> Tensor:
68
- output = torch.zeros(
69
- size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
70
- dtype=x.dtype, device=x.device)
71
-
72
- x_b, x_r, x_c = x.size()
73
- x_b_s, x_r_s, x_c_s = stride(x)
74
- s_lut_x_r, s_lut_x_c = sparsity_lut.size()
75
- s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
76
- o_b, o_r, o_c = output.size()
77
- o_b_s, o_r_s, o_c_s = stride(output)
78
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
79
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
80
-
81
- triton_grid = lambda meta: [x_b,
82
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
83
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
84
-
85
- (wrap_triton(row_wise_sum_kernel)[triton_grid]
86
- (x,
87
- x_b, x_b_s, x_r_s, x_c_s,
88
- sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
89
- output,
90
- o_b, o_b_s, o_r_s,
91
- s_l_o_b, s_l_o_b_s, s_l_o_r_s,
92
- sparsity_reverse_lut_output,
93
- sparsity_block_size))
94
-
95
- return output
68
+ with torch.no_grad():
69
+ output = torch.zeros(
70
+ size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
71
+ dtype=x.dtype, device=x.device)
72
+
73
+ x_b, x_r, x_c = x.size()
74
+ x_b_s, x_r_s, x_c_s = stride(x)
75
+ s_lut_x_r, s_lut_x_c = sparsity_lut.size()
76
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
77
+ o_b, o_r, o_c = output.size()
78
+ o_b_s, o_r_s, o_c_s = stride(output)
79
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
80
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
81
+
82
+ triton_grid = lambda meta: [x_b,
83
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
84
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
85
+
86
+ (wrap_triton(row_wise_sum_kernel)[triton_grid]
87
+ (x,
88
+ x_b, x_b_s, x_r_s, x_c_s,
89
+ sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
90
+ output,
91
+ o_b, o_b_s, o_r_s,
92
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s,
93
+ sparsity_reverse_lut_output,
94
+ sparsity_block_size))
95
+
96
+ return output
96
97
 
97
98
 
98
99
  # noinspection PyUnusedLocal
@@ -132,25 +133,22 @@ def row_wise_sum_kernel(x,
132
133
  rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
133
134
  rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
134
135
 
135
- if rev_idx_spa == -1:
136
- tl.device_assert(False)
137
- return
138
-
139
- blk_idx = ((pid_blk * x_b_s) +
140
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
141
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
142
- blk_msk = (blk_idx >= 0 and
143
- blk_idx < x_b * x_b_s)
144
- blk = tl.load(x + blk_idx, mask=blk_msk)
136
+ if rev_idx_spa >= 0:
137
+ blk_idx = ((pid_blk * x_b_s) +
138
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
139
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
140
+ blk_msk = (blk_idx >= 0 and
141
+ blk_idx < x_b * x_b_s)
142
+ blk = tl.load(x + blk_idx, mask=blk_msk)
145
143
 
146
- buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
144
+ buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
147
145
 
148
- o_idx = (rev_idx_spa * o_b_s +
149
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
150
- (tl.arange(0, 1))[None, :])
151
- o_msk = (o_idx >= 0 and
152
- o_idx < o_b * o_b_s)
153
- tl.atomic_add(o + o_idx, buf, o_msk)
146
+ o_idx = (rev_idx_spa * o_b_s +
147
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
148
+ (tl.arange(0, 1))[None, :])
149
+ o_msk = (o_idx >= 0 and
150
+ o_idx < o_b * o_b_s)
151
+ tl.atomic_add(o + o_idx, buf, o_msk)
154
152
 
155
153
 
156
154
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -176,7 +174,7 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
176
174
  of the input and the sparsity layout of the output tensor.
177
175
 
178
176
  """
179
- # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376
177
+ # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
180
178
  x = torch.where(x == -0.0, torch.tensor(0.0), x)
181
179
  x = x.contiguous()
182
180
 
@@ -204,41 +202,42 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
204
202
  n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
205
203
 
206
204
 
207
- @triton_op("blksprs::row_wise_max", mutates_args={})
205
+ @triton_op("blksprs::row_wise_max_forward", mutates_args={})
208
206
  def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
209
207
  sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
210
208
  sparsity_block_size: int, n_sparse_blocks_output: int,
211
209
  flag_slice_only: bool = False) -> Tensor:
212
- output = torch.full(size=(n_sparse_blocks_output,
213
- sparsity_block_size,
214
- 1 if flag_slice_only else sparsity_block_size),
215
- fill_value=torch.finfo(x.dtype).min,
216
- device=x.device)
217
-
218
- x_b, x_r, x_c = x.size()
219
- x_b_s, x_r_s, x_c_s = stride(x)
220
- s_lut_x_r, s_lut_x_c = sparsity_lut.size()
221
- s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
222
- o_b, o_r, o_c = output.size()
223
- o_b_s, o_r_s, o_c_s = stride(output)
224
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
225
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
226
-
227
- triton_grid = lambda meta: [x_b,
228
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
229
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
230
-
231
- (wrap_triton(row_wise_max_kernel)[triton_grid]
232
- (x,
233
- x_b, x_b_s, x_r_s, x_c_s,
234
- sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
235
- output,
236
- o_b, o_b_s, o_r_s,
237
- s_l_o_b, s_l_o_b_s, s_l_o_r_s,
238
- sparsity_reverse_lut_output,
239
- sparsity_block_size))
240
-
241
- return output
210
+ with torch.no_grad():
211
+ output = torch.full(size=(n_sparse_blocks_output,
212
+ sparsity_block_size,
213
+ 1 if flag_slice_only else sparsity_block_size),
214
+ fill_value=torch.finfo(x.dtype).min,
215
+ device=x.device)
216
+
217
+ x_b, x_r, x_c = x.size()
218
+ x_b_s, x_r_s, x_c_s = stride(x)
219
+ s_lut_x_r, s_lut_x_c = sparsity_lut.size()
220
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
221
+ o_b, o_r, o_c = output.size()
222
+ o_b_s, o_r_s, o_c_s = stride(output)
223
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
224
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
225
+
226
+ triton_grid = lambda meta: [x_b,
227
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
228
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
229
+
230
+ (wrap_triton(row_wise_max_kernel)[triton_grid]
231
+ (x,
232
+ x_b, x_b_s, x_r_s, x_c_s,
233
+ sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
234
+ output,
235
+ o_b, o_b_s, o_r_s,
236
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s,
237
+ sparsity_reverse_lut_output,
238
+ sparsity_block_size))
239
+
240
+ return output
242
241
 
243
242
 
244
243
  # noinspection PyUnusedLocal
@@ -278,25 +277,22 @@ def row_wise_max_kernel(x,
278
277
  rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
279
278
  rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
280
279
 
281
- if rev_idx_spa == -1:
282
- tl.device_assert(False)
283
- return
284
-
285
- blk_idx = ((pid_blk * x_b_s) +
286
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
287
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
288
- blk_msk = (blk_idx >= 0 and
289
- blk_idx < x_b * x_b_s)
290
- blk = tl.load(x + blk_idx, mask=blk_msk)
280
+ if rev_idx_spa >= 0:
281
+ blk_idx = ((pid_blk * x_b_s) +
282
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
283
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
284
+ blk_msk = (blk_idx >= 0 and
285
+ blk_idx < x_b * x_b_s)
286
+ blk = tl.load(x + blk_idx, mask=blk_msk)
291
287
 
292
- buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
288
+ buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
293
289
 
294
- o_idx = (rev_idx_spa * o_b_s +
295
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
296
- (tl.arange(0, 1))[None, :])
297
- o_msk = (o_idx >= 0 and
298
- o_idx < o_b * o_b_s)
299
- tl.atomic_max(o + o_idx, buf, o_msk)
290
+ o_idx = (rev_idx_spa * o_b_s +
291
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
292
+ (tl.arange(0, 1))[None, :])
293
+ o_msk = (o_idx >= 0 and
294
+ o_idx < o_b * o_b_s)
295
+ tl.atomic_max(o + o_idx, buf, o_msk)
300
296
 
301
297
 
302
298
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -343,41 +339,43 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
343
339
  return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)
344
340
 
345
341
 
346
- @triton_op("blksprs::row_wise_add", mutates_args={})
342
+ @triton_op("blksprs::row_wise_add_forward", mutates_args={})
347
343
  def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
348
344
  sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
349
345
  y: Tensor, sparsity_block_size: int) -> Tensor:
350
- output = torch.zeros_like(x)
351
-
352
- x_b, x_r, x_c = x.size()
353
- x_b_s, x_r_s, x_c_s = stride(x)
354
- s_lut_r, s_lut_c = sparsity_lut_x.size()
355
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
356
- y_b, y_r, y_c = y.size()
357
- y_b_s, y_r_s, y_c_s = stride(y)
358
- s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
359
- s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
360
- o_b, o_r, o_c = output.size()
361
- o_b_s, o_r_s, o_c_s = stride(output)
362
-
363
- triton_grid = lambda meta: [o_b,
364
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
365
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
366
-
367
- (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
368
- (x,
369
- x_b, x_b_s, x_r_s, x_c_s,
370
- sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
371
- y, y_b, y_b_s, y_r_s, y_c_s,
372
- s_l_y_b, s_l_y_b_s, s_l_y_r_s,
373
- sparsity_reverse_x_lut_rwm,
374
- output,
375
- o_b, o_b_s, o_r_s, o_c_s,
376
- sparsity_block_size))
377
-
378
- return output
346
+ with torch.no_grad():
347
+ output = torch.zeros_like(x)
348
+
349
+ x_b, x_r, x_c = x.size()
350
+ x_b_s, x_r_s, x_c_s = stride(x)
351
+ s_lut_r, s_lut_c = sparsity_lut_x.size()
352
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
353
+ y_b, y_r, y_c = y.size()
354
+ y_b_s, y_r_s, y_c_s = stride(y)
355
+ s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
356
+ s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
357
+ o_b, o_r, o_c = output.size()
358
+ o_b_s, o_r_s, o_c_s = stride(output)
359
+
360
+ triton_grid = lambda meta: [o_b,
361
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
362
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
363
+
364
+ (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
365
+ (x,
366
+ x_b, x_b_s, x_r_s, x_c_s,
367
+ sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
368
+ y, y_b, y_b_s, y_r_s, y_c_s,
369
+ s_l_y_b, s_l_y_b_s, s_l_y_r_s,
370
+ sparsity_reverse_x_lut_rwm,
371
+ output,
372
+ o_b, o_b_s, o_r_s, o_c_s,
373
+ sparsity_block_size))
374
+
375
+ return output
379
376
 
380
377
 
378
+ # noinspection PyUnusedLocal
381
379
  @triton.autotune(
382
380
  configs=get_autotune_configs(),
383
381
  key=["sparsity_block_size"],
@@ -46,14 +46,15 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
46
46
  partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
47
47
 
48
48
 
49
- @triton_op("blksprs::split", mutates_args={})
49
+ @triton_op("blksprs::split_forward", mutates_args={})
50
50
  def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
51
51
  _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
52
- return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
53
- n_sparse_blocks)
52
+ with torch.no_grad():
53
+ return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
54
+ n_sparse_blocks)
54
55
 
55
56
 
56
- def split_backward(ctx, grad_output):
57
+ def split_wrapper_backward(ctx, grad_output):
57
58
  sparsity_layout = ctx.saved_tensors[0]
58
59
  num_partitions = ctx.num_partitions
59
60
  dim = ctx.dim
@@ -109,7 +110,7 @@ def split_setup_context(ctx, inputs, output):
109
110
  ctx.sparsity_block_size = sparsity_block_size
110
111
 
111
112
 
112
- split_forward.register_autograd(split_backward, setup_context=split_setup_context)
113
+ split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)
113
114
 
114
115
 
115
116
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -150,14 +151,15 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
150
151
  partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
151
152
 
152
153
 
153
- @triton_op("blksprs::merge", mutates_args={})
154
+ @triton_op("blksprs::merge_forward", mutates_args={})
154
155
  def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
155
156
  _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
156
- return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
157
- n_sparse_blocks)
157
+ with torch.no_grad():
158
+ return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
159
+ n_sparse_blocks)
158
160
 
159
161
 
160
- def merge_backward(ctx, grad_output):
162
+ def merge_wrapper_backward(ctx, grad_output):
161
163
  sparsity_layout = ctx.saved_tensors[0]
162
164
  num_partitions = ctx.num_partitions
163
165
  dim = ctx.dim
@@ -216,4 +218,4 @@ def merge_setup_context(ctx, inputs, output):
216
218
  ctx.sparsity_block_size = sparsity_block_size
217
219
 
218
220
 
219
- merge_forward.register_autograd(merge_backward, setup_context=merge_setup_context)
221
+ merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
blksprs/ops/repeat.py CHANGED
@@ -92,15 +92,16 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
92
92
  lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
93
93
 
94
94
 
95
- @triton_op("blksprs::repeat", mutates_args={})
95
+ @triton_op("blksprs::repeat_forward", mutates_args={})
96
96
  def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
97
97
  sparsity_reverse_lut: Tensor,
98
98
  sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
99
- return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
100
- n_sparse_blocks)
99
+ with torch.no_grad():
100
+ return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
101
+ n_sparse_blocks)
101
102
 
102
103
 
103
- def repeat_backward(ctx, grad_output):
104
+ def repeat_wrapper_backward(ctx, grad_output):
104
105
  sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
105
106
  sparsity_block_size = ctx.sparsity_block_size
106
107
  n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
@@ -190,4 +191,4 @@ def repeat_setup_context(ctx, inputs, output):
190
191
  ctx.sparsity_block_size = sparsity_block_size
191
192
 
192
193
 
193
- repeat_forward.register_autograd(repeat_backward, setup_context=repeat_setup_context)
194
+ repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)