blksprs 2.0rc7__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/softmax.py CHANGED
@@ -1,3 +1,5 @@
+import pdb
+
 import torch
 import triton
 from torch import Tensor
@@ -7,6 +9,7 @@ from triton import language as tl
 
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
+from blksprs.utils.debugging import dbg_tensor_full
 from blksprs.utils.tools import stride
 from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
@@ -47,19 +50,13 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                                          sparsity_block_size))
 
 
-@triton_op("blksprs::softmax", mutates_args={})
+@triton_op("blksprs::softmax_forward", mutates_args={})
 def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_lut: Tensor,
                     sparsity_reverse_lut_rws: Tensor,
                     sparsity_block_size: int) -> Tensor:
     output = torch.zeros_like(x)
 
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    o_b, o_r, o_c = output.size()
-
     x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
                                                        flag_slice_only=True)
     x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
@@ -67,6 +64,11 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
     x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                            flag_slice_only=True)
 
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b, o_r, o_c = output.size()
     s_b, s_r, s_c = x_exp_row_wise_sum.shape
     s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
@@ -89,50 +91,58 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
     return output
 
 
-def softmax_backward(ctx, grad_output):
+def softmax_backward_wrapper(ctx, grad_output):
     o, sparsity_layout, sparsity_lut = ctx.saved_tensors
     sparsity_block_size = ctx.sparsity_block_size
 
-    s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
-
-    sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
-    sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
-                              (sparsity_layout_s_flat == 1) -
-                              (1 * (sparsity_layout_s_flat == 0)))
-
-    o_b, o_r, o_c = o.size()
-    o_b_s, o_r_s, o_c_s = stride(o)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    s_b, s_r, s_c = s.size()
-    s_b_s, s_r_s, s_c_s = stride(s)
-    s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
-    s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
-
-    grad_x = torch.zeros_like(o, dtype=torch.float)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    # TODO wrap
-    (softmax_kernel_grad[triton_grid]
-     (grad_output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      o,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-      s,
-      s_b, s_b_s, s_r_s, s_c_s,
-      s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-      sparsity_reverse_lut_s,
-      grad_x,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_block_size))
-
-    return grad_x, None, None, None, None, None
+    return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
+                            sparsity_block_size), None, None, None, None, None
+
+
+@triton_op("blksprs::softmax_backward", mutates_args={})
+def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
+                     sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        grad_x = torch.zeros_like(o, dtype=torch.float)
+
+        s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
+
+        sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
+        sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
+                                  (sparsity_layout_s_flat == 1) -
+                                  (1 * (sparsity_layout_s_flat == 0)))
+
+        o_b, o_r, o_c = o.size()
+        o_b_s, o_r_s, o_c_s = stride(o)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        s_b, s_r, s_c = s.size()
+        s_b_s, s_r_s, s_c_s = stride(s)
+        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
+        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(softmax_kernel_grad)[triton_grid]
+         (grad_output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          o,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          s,
+          s_b, s_b_s, s_r_s, s_c_s,
+          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+          sparsity_reverse_lut_s,
+          grad_x,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return grad_x
 
 
+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
     key=["sparsity_block_size"],
@@ -193,6 +203,7 @@ def softmax_kernel(x,
     tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
 
 
+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
     key=["sparsity_block_size"],
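
For reference, the sparsity_reverse_lut_s expression carried into the new blksprs::softmax_backward op maps each position of the flattened sparsity layout either to the index of its block in the compressed tensor or to -1 for absent blocks. A minimal sketch of what the cumsum trick produces, using a made-up 1x2x2 layout (not taken from the package):

    import torch

    # Made-up layout: one batch, a 2x2 grid of sparsity blocks, only the diagonal blocks present.
    sparsity_layout_s = torch.tensor([[[1, 0],
                                       [0, 1]]])

    flat = sparsity_layout_s.reshape(-1)
    reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1)
                   - (1 * (flat == 0)))

    print(reverse_lut)  # tensor([ 0, -1, -1,  1])
    # Present blocks map to their index in the compressed tensor, absent blocks to -1,
    # which the kernels use to mask out loads.
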
@@ -293,4 +304,239 @@ def softmax_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size
 
 
-softmax_forward.register_autograd(softmax_backward, setup_context=softmax_setup_context)
+softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
+
+
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                  lut: dict = None) -> BlksprsTensor:
+    """Computes the softmax fused for each row of a block-sparse tensor in compressed form.
+
+    Note:
+        This softmax implementation is a fused version that loads the entire row of a block-sparse tensor into memory.
+        See :func:`softmax` for a true block-wise softmax implementation.
+
+    Args:
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
+
+    Returns:
+        BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_dtype_float_32(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+
+    lut = softmax_fused_build_lut(lut, sparsity_layout)
+
+    return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
+                                               lut["sparsity_reverse_lut"],
+                                               sparsity_block_size))
+
+
+@triton_op("blksprs::softmax_fused_forward", mutates_args={})
+def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
+                          sparsity_reverse_lut: Tensor,
+                          sparsity_block_size: int) -> Tensor:
+    output = torch.zeros_like(x)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+    s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+
+    triton_grid = lambda meta: [s_l_b,
+                                s_l_r,
+                                sparsity_block_size]
+
+    (wrap_triton(softmax_fused_kernel)[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      output,
+      s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
+      sparsity_reverse_lut,
+      sparsity_block_size))
+
+    return output
+
+
+def softmax_fused_backward_wrapper(ctx, grad_output):
+    o, sparsity_layout, sparsity_reverse_lut = ctx.saved_tensors
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return softmax_fused_backward(grad_output, o, sparsity_reverse_lut, sparsity_layout,
+                                  sparsity_block_size), None, None, None, None, None
+
+
+@triton_op("blksprs::softmax_fused_backward", mutates_args={})
+def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut: Tensor, sparsity_layout: Tensor,
+                           sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        grad_x = torch.zeros_like(o)
+
+        g_b, g_r, g_c = grad_output.size()
+        g_b_s, g_r_s, g_c_s = stride(grad_output)
+        o_b, o_r, o_c = o.size()
+        o_b_s, o_r_s, o_c_s = stride(o)
+        s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+        s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+
+        triton_grid = lambda meta: [s_l_b,
+                                    s_l_r,
+                                    sparsity_block_size]
+
+        (wrap_triton(softmax_fused_kernel_grad)[triton_grid]
+         (grad_output,
+          g_b, g_b_s, g_r_s, g_c_s,
+          o,
+          o_b, o_b_s, o_r_s, o_c_s,
+          s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
+          sparsity_reverse_lut,
+          grad_x,
+          sparsity_block_size))
+
+        return grad_x
+
+
+# noinspection PyUnusedLocal
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def softmax_fused_kernel(x,
+                         x_b, x_b_s, x_r_s, x_c_s,
+                         o,
+                         s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
+                         r_lut,
+                         sparsity_block_size: tl.constexpr,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_bat = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_lin = tl.program_id(axis=2)
+
+    # Load reverse sparsity indices of row
+    blk_rev_idx = (pid_bat * s_l_b_s +
+                   pid_row * s_l_r_s +
+                   (tl.arange(0, s_l_c) * s_l_c_s))
+    blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
+    blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+
+    if (not (tl.min(blk_rev) == -1 and
+             tl.max(blk_rev) == -1)):
+        # Extend sparsity indices to cover sparsity blocks
+        blk_rev_ext = tl.expand_dims(blk_rev, -1)
+        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
+        blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
+
+        # Load line of x
+        blk_x_idx = (blk_rev_ext * x_b_s +
+                     pid_lin * x_r_s +
+                     (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
+        blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                      and blk_rev_ext != -1)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
+
+        # Compute softmax
+        blk_x_softmax = tl.softmax(blk_x)
+
+        # Store output
+        tl.store(o + blk_x_idx, blk_x_softmax, mask=blk_x_mask)
+
+
+# noinspection PyUnusedLocal
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def softmax_fused_kernel_grad(g,
+                              g_b, g_b_s, g_r_s, g_c_s,
+                              x,
+                              x_b, x_b_s, x_r_s, x_c_s,
+                              s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
+                              r_lut,
+                              o,
+                              sparsity_block_size: tl.constexpr,
+                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_bat = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_lin = tl.program_id(axis=2)
+
+    # Load reverse sparsity indices of row
+    blk_rev_idx = (pid_bat * s_l_b_s +
+                   pid_row * s_l_r_s +
+                   (tl.arange(0, s_l_c) * s_l_c_s))
+    blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
+    blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+
+    if (not (tl.min(blk_rev) == -1 and
+             tl.max(blk_rev) == -1)):
+        # Extend sparsity indices to cover sparsity blocks
+        blk_rev_ext = tl.expand_dims(blk_rev, -1)
+        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
+        blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
+
+        # Load line of g
+        blk_g_idx = (blk_rev_ext * g_b_s +
+                     pid_lin * g_r_s +
+                     (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * g_c_s)
+        blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
+                      and blk_rev_ext != -1)
+        blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
+
+        # Load line of x
+        blk_x_idx = (blk_rev_ext * x_b_s +
+                     pid_lin * x_r_s +
+                     (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
+        blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                      and blk_rev_ext != -1)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
+
+        # Compute gradients
+        blk_grad = blk_x * (blk_g - tl.sum(blk_x * blk_g))
+
+        tl.store(o + blk_x_idx, blk_grad, mask=blk_x_mask)
+
+
+def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout.size())
+                                .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+    validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def softmax_fused_setup_context(ctx, inputs, output):
+    (_, sparsity_layout, sparsity_reverse_lut, sparsity_block_size) = inputs
+
+    ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut)
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+softmax_fused_forward.register_autograd(softmax_fused_backward_wrapper, setup_context=softmax_fused_setup_context)
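
Taken together, the additions above give 2.1 two softmax entry points: the block-wise softmax and the new fused softmax_fused, which processes whole rows. A minimal usage sketch under stated assumptions: a CUDA device, float32 inputs (enforced by validate_dtype_float_32 on the fused path), and the compressed representation of shape (num_sparse_blocks, sparsity_block_size, sparsity_block_size) used by the library; the toy layout and sizes are made up:

    import torch
    from blksprs.ops.softmax import softmax, softmax_fused
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    sparsity_block_size = 64
    # Made-up layout: one batch, 2x2 grid of blocks, only the diagonal blocks present.
    sparsity_layout = torch.tensor([[[1, 0],
                                     [0, 1]]], device="cuda")

    # Compressed form: one (block_size x block_size) slice per present block.
    x = BlksprsTensor(torch.randn(2, sparsity_block_size, sparsity_block_size,
                                  dtype=torch.float32, device="cuda"))

    y_blockwise = softmax(x, sparsity_layout, sparsity_block_size)    # block-wise variant
    y_fused = softmax_fused(x, sparsity_layout, sparsity_block_size)  # fused row-wise variant
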
blksprs/ops/transpose.py CHANGED
@@ -28,7 +28,6 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
 
     """
     x = x.contiguous()
-    x_t = x.transpose(-1, -2).contiguous()
 
     validate_dimensions(x)
     validate_contiguous(x)
@@ -38,20 +37,22 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
 
     lut = transpose_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(transpose_forward(x_t, lut["sparsity_layout_t"],
+    return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
                                            lut["sparsity_lut"], lut["sparsity_reverse_lut"],
                                            sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
 
 
-@triton_op("blksprs::transpose", mutates_args={})
+@triton_op("blksprs::transpose_forward", mutates_args={})
 def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                             sparsity_block_size, n_sparse_blocks)
+    with torch.no_grad():
+        x_t = x.transpose(-1, -2).contiguous()
+        return flow_pull_forward(x_t, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                 sparsity_block_size, n_sparse_blocks)
 
 
-def transpose_backward(ctx, grad_output):
+def transpose_wrapper_backward(ctx, grad_output):
     sparsity_layout = ctx.saved_tensors[0]
     sparsity_block_size = ctx.sparsity_block_size
 
@@ -96,4 +97,4 @@ def transpose_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size
 
 
-transpose_forward.register_autograd(transpose_backward, setup_context=transpose_setup_context)
+transpose_forward.register_autograd(transpose_wrapper_backward, setup_context=transpose_setup_context)
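
With this change the physical transpose(-1, -2) happens inside the registered blksprs::transpose_forward op (under torch.no_grad()), so callers hand over the compressed tensor as-is and gradients are routed through the registered transpose_wrapper_backward. A minimal call sketch, assuming the public signature continues with sparsity_block_size (and an optional lut) as in the other ops in this diff; the toy data is made up:

    import torch
    from blksprs.ops.transpose import transpose
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    sparsity_block_size = 64
    sparsity_layout = torch.tensor([[[1, 1],
                                     [0, 1]]], device="cuda")
    x = BlksprsTensor(torch.randn(3, sparsity_block_size, sparsity_block_size, device="cuda"))

    # Returns the transposed block-sparse tensor together with the transposed sparsity layout.
    x_t, sparsity_layout_t = transpose(x, sparsity_layout, sparsity_block_size)
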
@@ -2,15 +2,7 @@ import os
 
 blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")
 
-if blksprs_autotune_mode == "TEST":
-    autotune_parameters = [
-        (16, 3, 8),
-
-        (32, 3, 8),
-
-        (64, 3, 8),
-    ]
-elif blksprs_autotune_mode == "DEFAULT":
+if blksprs_autotune_mode == "DEFAULT":
     autotune_parameters = [
         (16, 3, 8),
         (16, 4, 4),
@@ -28,6 +20,14 @@ elif blksprs_autotune_mode == "DEFAULT":
         (128, 4, 4),
         (128, 5, 2),
     ]
+elif blksprs_autotune_mode == "TEST":
+    autotune_parameters = [
+        (16, 3, 8),
+
+        (32, 3, 8),
+
+        (64, 3, 8),
+    ]
 else:
     raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")
 
@@ -75,4 +75,4 @@ def get_autotune_configs():
         autotune_configs.append(
             triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
 
-    return autotune_configs
+    return autotune_configs
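
The reorder above only changes which branch is checked first; the parameter set is still selected via the BLKSPRS_AUTOTUNE environment variable, which is read once at module import (os.getenv runs at module level). A minimal sketch of switching to the reduced TEST set:

    import os

    os.environ["BLKSPRS_AUTOTUNE"] = "TEST"  # "DEFAULT" is used when the variable is unset

    # Import only after setting the variable, since the mode is fixed at import time.
    from blksprs.utils.autotuning import get_autotune_configs

    configs = get_autotune_configs()
    # Each entry is a triton.Config built from a (block_size, num_stages, num_warps) tuple.
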
@@ -26,7 +26,6 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
 
     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-    # TODO At the moment, manual cast is needed. Bug with custom_fwd?
     xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
     interim = xw
 
blksprs/utils/tools.py CHANGED
@@ -5,14 +5,14 @@ from torch import Tensor, Size
 torch._dynamo.config.capture_scalar_outputs = True
 
 
-def do_shape_blocksparse(x: Tensor):
+def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
     if x.dim() == 3:
         return x.contiguous(), x.size()
 
     return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
 
 
-def undo_shape_blocksparse(x: Tensor, shape: Size):
+def undo_shape_blocksparse(x: Tensor, shape: Size | tuple[int, ...]) -> Tensor:
     if x.shape[:-2] == shape[:-2]:
         return x
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.0rc7
+Version: 2.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -0,0 +1,23 @@
+blksprs/__init__.py,sha256=o_Rj7fz_70vbMGLePihczVIVcM8E28vY3ah-d1q4ZO0,1613
+blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
+blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
+blksprs/ops/conversion.py,sha256=kf5HKofZ4nVeHCIqQoYKiIlgsAhq33Tnmnr1c17Fkqs,21906
+blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
+blksprs/ops/flow.py,sha256=qdWBCLDSkKaa8CAfkO1NgH-J5N7yMsILyR7qEpyrIUU,8246
+blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
+blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
+blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
+blksprs/ops/softmax.py,sha256=H0OxST_XX1QLa7HDTDHznzibVHAxnp5sVbMU32HLxf0,21967
+blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
+blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
+blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
+blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2052
+blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
+blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
+blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
+blksprs/utils/tools.py,sha256=CPf7viQ2OTcZFrB1aSL8_us4VE9M6YEfDz2dE30jr9I,715
+blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
+blksprs-2.1.dist-info/METADATA,sha256=uPVm8Y7fX5iModz6j3hNAftdtauCsJ-iYrMa-Pv3xnU,9506
+blksprs-2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (78.1.0)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
@@ -1,23 +0,0 @@
-blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
-blksprs/layouting/distribution_layout.py,sha256=TkMh_DYKX56Cb8Vq7EHyupMRvzm0XbUNP8QP7afv9wM,5122
-blksprs/layouting/sparsity_layout.py,sha256=6GOjwllDUK9L8jEQNu2i17Pp1BIIQm8fv3xVuiR0zIw,10228
-blksprs/ops/conversion.py,sha256=2zAdbaZ1iP2lisLVeG-k-f571G4HJapADhSwpY0Zd3o,21503
-blksprs/ops/distribution.py,sha256=6joac_zl3ZnRkPqLPQ0d88r7IbcrWAg0HiV93LOZw-w,20453
-blksprs/ops/flow.py,sha256=UO5ba5TFgVpEyT7r0hnWYw3vhRDpBOxyPHUBeNOAYPs,7935
-blksprs/ops/matmul.py,sha256=02hujXMtFgF7ohepM3v6h9okrfcU-J3mQZV17B-qvh0,12235
-blksprs/ops/partitioning.py,sha256=nAV28f3NtvT4OFvDtnE0A-VxpDQmMXS0pZw4CJwzqGA,9838
-blksprs/ops/repeat.py,sha256=bQpJuwtt8aRdSzxT78lJ8f8fLDhPkYK5UvMfJ-PQrkc,8977
-blksprs/ops/softmax.py,sha256=-NoTf1Cpuku9C99N0LuMydT_ObozWTnZJGDZxseXEXI,12209
-blksprs/ops/transpose.py,sha256=PQKteFnzNAOEC7voO7wh_dq9c54UjCboJz889aBCwKc,4010
-blksprs/ops/misc/broadcast_ops.py,sha256=DhUbliT9TBT6zlEjutBmY1EAEUPmYOt2mKQ5i46vN1c,5880
-blksprs/ops/misc/row_wise.py,sha256=5u_J8WOTepvf6XtZ8r0lLPofYrI5fGB7mxSmGC81IR0,19167
-blksprs/utils/autotuning.py,sha256=tDfMWklm2rvbo0-ahH81C3Gg0U6LHjPn3d_3pEOzmJs,2053
-blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
-blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
-blksprs/utils/processing.py,sha256=xuu9iDpwTvsqI_WKMSD8QCNuvPnfcKMRcuF2L4Zs6Ts,3808
-blksprs/utils/tools.py,sha256=3_2IBbd54vVU4-6m2KtAN7qjU6jeF4UfPkbjeFqMpYo,664
-blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
-blksprs-2.0rc7.dist-info/METADATA,sha256=ER9DHdVeYUZUsjE-2bEB9fePw0FVI1vknwPNrj7mDPE,9509
-blksprs-2.0rc7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-blksprs-2.0rc7.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-2.0rc7.dist-info/RECORD,,