blksprs 2.0rc8__tar.gz → 2.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {blksprs-2.0rc8 → blksprs-2.1.1}/PKG-INFO +3 -2
  2. {blksprs-2.0rc8 → blksprs-2.1.1}/README.md +1 -0
  3. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/__init__.py +3 -2
  4. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/conversion.py +4 -4
  5. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/flow.py +1 -1
  6. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/softmax.py +252 -3
  7. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/tools.py +2 -10
  8. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/PKG-INFO +3 -2
  9. {blksprs-2.0rc8 → blksprs-2.1.1}/pyproject.toml +2 -2
  10. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/layouting/distribution_layout.py +0 -0
  11. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/layouting/sparsity_layout.py +0 -0
  12. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/distribution.py +0 -0
  13. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/matmul.py +0 -0
  14. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/misc/broadcast_ops.py +0 -0
  15. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/misc/row_wise.py +0 -0
  16. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/partitioning.py +0 -0
  17. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/repeat.py +0 -0
  18. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/transpose.py +0 -0
  19. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/autotuning.py +0 -0
  20. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/benchmarking.py +0 -0
  21. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/blksprs_tensor.py +0 -0
  22. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/processing.py +0 -0
  23. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/validation.py +0 -0
  24. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/SOURCES.txt +0 -0
  25. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/dependency_links.txt +0 -0
  26. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/requires.txt +0 -0
  27. {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/top_level.txt +0 -0
  28. {blksprs-2.0rc8 → blksprs-2.1.1}/setup.cfg +0 -0
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.0rc8
4
- Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
3
+ Version: 2.1.1
4
+ Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
7
7
  Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
@@ -197,6 +197,7 @@ def test_readme():
197
197
  # Other available functions
198
198
  bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
199
199
  bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
200
+ bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster fused version; requires that each row of the matrix fits into memory
200
201
  bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
201
202
  bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
202
203
 
@@ -178,6 +178,7 @@ def test_readme():
178
178
  # Other available functions
179
179
  bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
180
180
  bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
181
+ bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster fused version; requires that each row of the matrix fits into memory
181
182
  bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
182
183
  bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
183
184
 
@@ -1,12 +1,13 @@
1
- from blksprs.utils.tools import version
2
1
  from blksprs.utils.blksprs_tensor import BlksprsTensor
3
2
 
3
+ __version__ = "2.1.1"
4
+
4
5
 
5
6
  class ops:
6
7
  from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
7
8
  from blksprs.ops.distribution import gather, scatter, scatter_reduce
8
9
  from blksprs.ops.matmul import matmul
9
- from blksprs.ops.softmax import softmax
10
+ from blksprs.ops.softmax import softmax, softmax_fused
10
11
  from blksprs.ops.transpose import transpose
11
12
  from blksprs.ops.repeat import repeat, repeat_interleave
12
13
  from blksprs.ops.partitioning import split, merge
@@ -13,7 +13,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
13
13
 
14
14
 
15
15
  def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
16
- """Wrapper for ``to_sparse``.
16
+ """Wrapper for :func:`to_sparse`.
17
17
 
18
18
  """
19
19
  return to_sparse(x, sparsity_layout, sparsity_block_size)
@@ -167,7 +167,7 @@ to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to
167
167
 
168
168
  def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
169
169
  sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
170
- """Wrapper for ``to_dense``.
170
+ """Wrapper for :func:`to_dense`.
171
171
 
172
172
  """
173
173
  return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)
@@ -204,8 +204,8 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
204
204
  if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
205
205
  return x
206
206
 
207
- return to_dense_forward(x, sparsity_layout,
208
- lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
207
+ return Tensor(to_dense_forward(x, sparsity_layout,
208
+ lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
209
209
 
210
210
 
211
211
  @triton_op("blksprs::to_dense_forward", mutates_args={})
@@ -78,7 +78,7 @@ def flow_pull_kernel(x,
78
78
  spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
79
79
  spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
80
80
 
81
- # Get reverse sparsity index
81
+ # Load reverse sparsity index
82
82
  rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
83
83
  spa_row * s_l_o_r_s +
84
84
  spa_col * s_l_o_c_s)
@@ -1,3 +1,5 @@
1
+ import pdb
2
+
1
3
  import torch
2
4
  import triton
3
5
  from torch import Tensor
@@ -13,8 +15,20 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
13
15
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
14
16
 
15
17
 
18
+ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, flag_fused: bool = True,
19
+ lut: dict = None) -> BlksprsTensor:
20
+ """Wrapper for :func:`softmax_regular` and :func:`softmax_fused` based on the ``flag_fused`` parameter.
21
+
22
+ """
23
+ if flag_fused:
24
+ return softmax_fused(x, sparsity_layout, sparsity_block_size, lut)
25
+ else:
26
+ return softmax_regular(x, sparsity_layout, sparsity_block_size, lut)
27
+
28
+
16
29
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
17
- def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
30
+ def softmax_regular(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
31
+ lut: dict = None) -> BlksprsTensor:
18
32
  """Computes the softmax of a block-sparse tensor in compressed form.
19
33
 
20
34
  Note:
@@ -100,6 +114,8 @@ def softmax_backward_wrapper(ctx, grad_output):
100
114
  def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
101
115
  sparsity_block_size: int) -> Tensor:
102
116
  with torch.no_grad():
117
+ grad_x = torch.zeros_like(o, dtype=torch.float)
118
+
103
119
  s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
104
120
 
105
121
  sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
@@ -116,8 +132,6 @@ def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, spars
116
132
  s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
117
133
  s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
118
134
 
119
- grad_x = torch.zeros_like(o, dtype=torch.float)
120
-
121
135
  triton_grid = lambda meta: [o_b,
122
136
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
123
137
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
@@ -302,3 +316,238 @@ def softmax_setup_context(ctx, inputs, output):
302
316
 
303
317
 
304
318
  softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
319
+
320
+
321
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
322
+ def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
323
+ lut: dict = None) -> BlksprsTensor:
324
+ """Computes the softmax fused for each row of a block-sparse tensor in compressed form.
325
+
326
+ Note:
327
+ This softmax implementation is a fused version that loads the entire row of a block-sparse tensor into memory.
328
+ See :func:`softmax_regular` for a true block-wise softmax implementation.
329
+
330
+ Args:
331
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
332
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
333
+ sparsity_block_size (int): The size of the sparsity blocks.
334
+ lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
335
+
336
+ Returns:
337
+ BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
338
+
339
+ """
340
+ x = x.contiguous()
341
+
342
+ validate_dimensions(x)
343
+ validate_contiguous(x)
344
+ validate_dtype_float_32(x)
345
+ validate_device(x)
346
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
347
+ validate_sparsity_block_size(sparsity_block_size, x)
348
+
349
+ lut = softmax_fused_build_lut(lut, sparsity_layout)
350
+
351
+ return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
352
+ lut["sparsity_reverse_lut"],
353
+ sparsity_block_size))
354
+
355
+
356
+ @triton_op("blksprs::softmax_fused_forward", mutates_args={})
357
+ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
358
+ sparsity_reverse_lut: Tensor,
359
+ sparsity_block_size: int) -> Tensor:
360
+ output = torch.zeros_like(x)
361
+
362
+ x_b, x_r, x_c = x.size()
363
+ x_b_s, x_r_s, x_c_s = stride(x)
364
+ s_l_b, s_l_r, s_l_c = sparsity_layout.size()
365
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
366
+
367
+ triton_grid = lambda meta: [s_l_b,
368
+ s_l_r,
369
+ sparsity_block_size]
370
+
371
+ (wrap_triton(softmax_fused_kernel)[triton_grid]
372
+ (x,
373
+ x_b, x_b_s, x_r_s, x_c_s,
374
+ output,
375
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
376
+ sparsity_reverse_lut,
377
+ sparsity_block_size))
378
+
379
+ return output
380
+
381
+
382
+ def softmax_fused_backward_wrapper(ctx, grad_output):
383
+ o, sparsity_layout, sparsity_reverse_lut = ctx.saved_tensors
384
+ sparsity_block_size = ctx.sparsity_block_size
385
+
386
+ return softmax_fused_backward(grad_output, o, sparsity_reverse_lut, sparsity_layout,
387
+ sparsity_block_size), None, None, None, None, None
388
+
389
+
390
+ @triton_op("blksprs::softmax_fused_backward", mutates_args={})
391
+ def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut: Tensor, sparsity_layout: Tensor,
392
+ sparsity_block_size: int) -> Tensor:
393
+ with torch.no_grad():
394
+ grad_x = torch.zeros_like(o)
395
+
396
+ g_b, g_r, g_c = grad_output.size()
397
+ g_b_s, g_r_s, g_c_s = stride(grad_output)
398
+ o_b, o_r, o_c = o.size()
399
+ o_b_s, o_r_s, o_c_s = stride(o)
400
+ s_l_b, s_l_r, s_l_c = sparsity_layout.size()
401
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
402
+
403
+ triton_grid = lambda meta: [s_l_b,
404
+ s_l_r,
405
+ sparsity_block_size]
406
+
407
+ (wrap_triton(softmax_fused_kernel_grad)[triton_grid]
408
+ (grad_output,
409
+ g_b, g_b_s, g_r_s, g_c_s,
410
+ o,
411
+ o_b, o_b_s, o_r_s, o_c_s,
412
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
413
+ sparsity_reverse_lut,
414
+ grad_x,
415
+ sparsity_block_size))
416
+
417
+ return grad_x
418
+
419
+
420
+ # noinspection PyUnusedLocal
421
+ @triton.autotune(
422
+ configs=get_autotune_configs(),
423
+ key=["sparsity_block_size"],
424
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
425
+ reset_to_zero=["o"]
426
+ )
427
+ @triton.jit
428
+ def softmax_fused_kernel(x,
429
+ x_b, x_b_s, x_r_s, x_c_s,
430
+ o,
431
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
432
+ r_lut,
433
+ sparsity_block_size: tl.constexpr,
434
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
435
+ # Get triton block indices
436
+ pid_bat = tl.program_id(axis=0)
437
+ pid_row = tl.program_id(axis=1)
438
+ pid_lin = tl.program_id(axis=2)
439
+
440
+ # Load reverse sparsity indices of row
441
+ blk_rev_idx = (pid_bat * s_l_b_s +
442
+ pid_row * s_l_r_s +
443
+ (tl.arange(0, s_l_c) * s_l_c_s))
444
+ blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
445
+ blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
446
+
447
+ if (not (tl.min(blk_rev) == -1 and
448
+ tl.max(blk_rev) == -1)):
449
+ # Extend sparsity indices to cover sparsity blocks
450
+ blk_rev_ext = tl.expand_dims(blk_rev, -1)
451
+ blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
452
+ blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
453
+
454
+ # Load line of x
455
+ blk_x_idx = (blk_rev_ext * x_b_s +
456
+ pid_lin * x_r_s +
457
+ (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
458
+ blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
459
+ and blk_rev_ext != -1)
460
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
461
+
462
+ # Compute softmax
463
+ blk_x_softmax = tl.softmax(blk_x)
464
+
465
+ # Store output
466
+ tl.store(o + blk_x_idx, blk_x_softmax, mask=blk_x_mask)
467
+
468
+
469
+ # noinspection PyUnusedLocal
470
+ @triton.autotune(
471
+ configs=get_autotune_configs(),
472
+ key=["sparsity_block_size"],
473
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
474
+ reset_to_zero=["o"]
475
+ )
476
+ @triton.jit
477
+ def softmax_fused_kernel_grad(g,
478
+ g_b, g_b_s, g_r_s, g_c_s,
479
+ x,
480
+ x_b, x_b_s, x_r_s, x_c_s,
481
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
482
+ r_lut,
483
+ o,
484
+ sparsity_block_size: tl.constexpr,
485
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
486
+ # Get triton block indices
487
+ pid_bat = tl.program_id(axis=0)
488
+ pid_row = tl.program_id(axis=1)
489
+ pid_lin = tl.program_id(axis=2)
490
+
491
+ # Load reverse sparsity indices of row
492
+ blk_rev_idx = (pid_bat * s_l_b_s +
493
+ pid_row * s_l_r_s +
494
+ (tl.arange(0, s_l_c) * s_l_c_s))
495
+ blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
496
+ blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
497
+
498
+ if (not (tl.min(blk_rev) == -1 and
499
+ tl.max(blk_rev) == -1)):
500
+ # Extend sparsity indices to cover sparsity blocks
501
+ blk_rev_ext = tl.expand_dims(blk_rev, -1)
502
+ blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
503
+ blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
504
+
505
+ # Load line of g
506
+ blk_g_idx = (blk_rev_ext * g_b_s +
507
+ pid_lin * g_r_s +
508
+ (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * g_c_s)
509
+ blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
510
+ and blk_rev_ext != -1)
511
+ blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
512
+
513
+ # Load line of x
514
+ blk_x_idx = (blk_rev_ext * x_b_s +
515
+ pid_lin * x_r_s +
516
+ (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
517
+ blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
518
+ and blk_rev_ext != -1)
519
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
520
+
521
+ # Compute gradients
522
+ blk_grad = blk_x * (blk_g - tl.sum(blk_x * blk_g))
523
+
524
+ tl.store(o + blk_x_idx, blk_grad, mask=blk_x_mask)
525
+
526
+
527
+ def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
528
+ if lut is None:
529
+ lut = dict()
530
+
531
+ if "sparsity_reverse_lut" not in lut:
532
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
533
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
534
+ (sparsity_layout_flat == 1) -
535
+ (1 * (sparsity_layout_flat == 0)))
536
+ .reshape(sparsity_layout.size())
537
+ .reshape(-1).contiguous())
538
+ lut["sparsity_reverse_lut"] = sparsity_reverse_lut
539
+
540
+ validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut"])
541
+
542
+ return lut
543
+
544
+
545
+ # noinspection PyUnusedLocal
546
+ def softmax_fused_setup_context(ctx, inputs, output):
547
+ (_, sparsity_layout, sparsity_reverse_lut, sparsity_block_size) = inputs
548
+
549
+ ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut)
550
+ ctx.sparsity_block_size = sparsity_block_size
551
+
552
+
553
+ softmax_fused_forward.register_autograd(softmax_fused_backward_wrapper, setup_context=softmax_fused_setup_context)
@@ -1,6 +1,3 @@
1
- import tomllib
2
- from pathlib import Path
3
-
4
1
  import torch
5
2
  from torch import Tensor, Size
6
3
 
@@ -8,19 +5,14 @@ from torch import Tensor, Size
8
5
  torch._dynamo.config.capture_scalar_outputs = True
9
6
 
10
7
 
11
- def version():
12
- with open(Path(__file__).parent.parent.parent.joinpath("pyproject.toml"), "rb") as f:
13
- return tomllib.load(f)["project"]["version"]
14
-
15
-
16
- def do_shape_blocksparse(x: Tensor):
8
+ def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
17
9
  if x.dim() == 3:
18
10
  return x.contiguous(), x.size()
19
11
 
20
12
  return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
21
13
 
22
14
 
23
- def undo_shape_blocksparse(x: Tensor, shape: Size):
15
+ def undo_shape_blocksparse(x: Tensor, shape: Size | tuple[int, ...]) -> Tensor:
24
16
  if x.shape[:-2] == shape[:-2]:
25
17
  return x
26
18
 
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.0rc8
4
- Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
3
+ Version: 2.1.1
4
+ Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
7
7
  Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
@@ -197,6 +197,7 @@ def test_readme():
197
197
  # Other available functions
198
198
  bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
199
199
  bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
200
+ bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster fused version; requires that each row of the matrix fits into memory
200
201
  bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
201
202
  bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
202
203
 
@@ -1,8 +1,8 @@
1
1
  [project]
2
2
  name = "blksprs"
3
- version = "2.0-rc.8"
3
+ version = "2.1.1"
4
4
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
5
- description = "A lightweight library for operations on blocksparse matrices in PyTorch."
5
+ description = "A lightweight library for operations on block-sparse matrices in PyTorch."
6
6
  readme = "README.md"
7
7
  requires-python = ">=3.11"
8
8
  license = { file = "LICENSE.md" }
File without changes
File without changes
File without changes