blksprs 2.0rc6__tar.gz → 2.0rc8__tar.gz

Files changed (28)
  1. {blksprs-2.0rc6 → blksprs-2.0rc8}/PKG-INFO +7 -3
  2. {blksprs-2.0rc6 → blksprs-2.0rc8}/README.md +6 -2
  3. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/__init__.py +1 -0
  4. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/layouting/distribution_layout.py +39 -26
  5. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/layouting/sparsity_layout.py +58 -45
  6. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/conversion.py +86 -84
  7. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/distribution.py +81 -79
  8. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/flow.py +64 -60
  9. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/matmul.py +50 -55
  10. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/misc/broadcast_ops.py +29 -27
  11. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/misc/row_wise.py +134 -132
  12. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/partitioning.py +12 -10
  13. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/repeat.py +6 -5
  14. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/softmax.py +55 -47
  15. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/ops/transpose.py +8 -7
  16. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/utils/autotuning.py +10 -10
  17. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/utils/processing.py +0 -1
  18. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/utils/tools.py +8 -9
  19. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs.egg-info/PKG-INFO +7 -3
  20. {blksprs-2.0rc6 → blksprs-2.0rc8}/pyproject.toml +1 -1
  21. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/utils/benchmarking.py +0 -0
  22. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/utils/blksprs_tensor.py +0 -0
  23. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs/utils/validation.py +0 -0
  24. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs.egg-info/SOURCES.txt +0 -0
  25. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs.egg-info/dependency_links.txt +0 -0
  26. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs.egg-info/requires.txt +0 -0
  27. {blksprs-2.0rc6 → blksprs-2.0rc8}/blksprs.egg-info/top_level.txt +0 -0
  28. {blksprs-2.0rc6 → blksprs-2.0rc8}/setup.cfg +0 -0
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
Metadata-Version: 2.4
Name: blksprs
- Version: 2.0rc6
+ Version: 2.0rc8
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -108,12 +108,16 @@ library.

## Known Limitations and Issues

+ - Triton has a bug with `tl.atomic_max()`, which is used for the row-wise max operation.
+ To work around this bug, a manual conversion of some values is needed, (slightly) negatively impacting
+ performance.
+ Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
  which could impact graph compilation.
- There seem to be some issues with autocasting, forcing some operations to manually cast.
- There will be some slight numerical differences between vanilla and blksprs operations.
- These instabilities are due to Triton and thus cannot be fixed by this library alone.
- However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

## Usage

--- a/README.md
+++ b/README.md
@@ -89,12 +89,16 @@ library.

## Known Limitations and Issues

+ - Triton has a bug with `tl.atomic_max()`, which is used for the row-wise max operation.
+ To work around this bug, a manual conversion of some values is needed, (slightly) negatively impacting
+ performance.
+ Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
  which could impact graph compilation.
- There seem to be some issues with autocasting, forcing some operations to manually cast.
- There will be some slight numerical differences between vanilla and blksprs operations.
- These instabilities are due to Triton and thus cannot be fixed by this library alone.
- However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

## Usage

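The new `tl.atomic_max()` bullet is easiest to understand with a concrete sketch. The diff does not show the workaround itself; one common order-preserving conversion for float atomics (a plausible reading of the "manual conversion of some values", not the verified blksprs implementation) reinterprets float bits as integers and remaps negatives so that integer comparisons agree with float ordering:

```python
import torch

def float_to_ordered_int(x: torch.Tensor) -> torch.Tensor:
    # Reinterpret float32 bits as int32; negative floats have their
    # non-sign bits flipped so that signed-integer ordering matches
    # float ordering (the classic radix-sort / atomic-max trick).
    # Hypothetical helper for illustration, not a blksprs function.
    ix = x.view(torch.int32)
    return torch.where(ix >= 0, ix, ix ^ 0x7FFFFFFF)

vals = torch.tensor([float("-inf"), -2.0, -0.0, 0.0, 1.5, float("inf")])
mapped = float_to_ordered_int(vals)
assert torch.all(mapped[:-1] <= mapped[1:])  # ordering preserved
```

Because the mapping is monotone, an integer atomic max on the mapped values yields the true row-wise float maximum, at the cost of the extra conversions the README mentions.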
--- a/blksprs/__init__.py
+++ b/blksprs/__init__.py
@@ -1,3 +1,4 @@
+ from blksprs.utils.tools import version
from blksprs.utils.blksprs_tensor import BlksprsTensor


--- a/blksprs/layouting/distribution_layout.py
+++ b/blksprs/layouting/distribution_layout.py
@@ -1,6 +1,10 @@
+ import typing
+
import torch
import triton
from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
from triton import language as tl

from blksprs.utils.blksprs_tensor import BlksprsTensor
@@ -10,6 +14,7 @@ from blksprs.utils.validation import validate_dimensions, validate_device, \
    validate_contiguous


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                              dim: int, size_target: torch.Size,
                              sparsity_block_size: int) -> Tensor:
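For context on the `custom_fwd` decorator added in this hunk: when an autocast region is active, it casts incoming CUDA floating-point tensor arguments to the given dtype and disables autocast inside the wrapped function. A minimal sketch, assuming PyTorch 2.4+ (`torch.amp.custom_fwd`) and mirroring the diff's use on a free function; `halve` is a hypothetical example, not part of blksprs:

```python
import torch

@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def halve(x: torch.Tensor) -> torch.Tensor:
    # Autocast is disabled in this body; under an active
    # torch.autocast("cuda") region, floating-point CUDA tensor
    # arguments arrive already cast to float16.
    return x * 0.5

if torch.cuda.is_available():
    with torch.autocast("cuda"):
        y = halve(torch.randn(4, device="cuda"))
    print(y.dtype)  # torch.float16
```

This pins the kernels to a known dtype regardless of the surrounding autocast policy, which fits the README's note that autocasting otherwise forces some operations to cast manually.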
@@ -34,32 +39,40 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T

    adjusted_dim = dim % 3

-     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                          dtype=torch.bool, device=indices.device)
-
-     i_b, i_r, i_c = indices.size()
-     i_b_s, i_r_s, i_c_s = stride(indices)
-     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-     s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     triton_grid = lambda meta: [i_b,
-                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
-
-     # TODO wrap
-     (build_distribution_layout_kernel[triton_grid]
-      (indices,
-       i_b, i_b_s, i_r_s, i_c_s,
-       sparsity_lut_i,
-       s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-       adjusted_dim,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
-
-     return output
+     return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
+
+
+ @triton_op("blksprs::build_distribution_layout", mutates_args={})
+ def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
+                                         adjusted_dim: int, size_target: typing.List[int],
+                                         sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
+                              size_target[2] // sparsity_block_size,
+                              dtype=torch.bool, device=indices.device)
+
+         i_b, i_r, i_c = indices.size()
+         i_b_s, i_r_s, i_c_s = stride(indices)
+         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+         s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [i_b,
+                                     triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(build_distribution_layout_kernel)[triton_grid]
+          (indices,
+           i_b, i_b_s, i_r_s, i_c_s,
+           sparsity_lut_i,
+           s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+           adjusted_dim,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output


@triton.autotune(
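The hunk above shows the refactoring pattern that recurs throughout this release: raw `kernel[grid](...)` launches (previously marked `# TODO wrap`) become `triton_op`-registered operations whose kernels launch through `wrap_triton`, making them traceable for `torch.compile`. A minimal self-contained sketch of the same pattern, assuming PyTorch ≥ 2.6 (mirroring the diff's private `torch._library` import paths); `demo::add_one` and its kernel are hypothetical, not blksprs APIs:

```python
import torch
import triton
from torch._library import triton_op
from torch._library.triton import wrap_triton
from triton import language as tl


@triton.jit
def add_one_kernel(x, out, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(x + offsets, mask=mask)
    tl.store(out + offsets, vals + 1, mask=mask)


@triton_op("demo::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the diff: allocate the output, compute the launch grid,
    # then launch through wrap_triton() so the call can be traced.
    output = torch.empty_like(x)
    n_elements = x.numel()
    triton_grid = lambda meta: [triton.cdiv(n_elements, meta["BLOCK_SIZE"])]
    wrap_triton(add_one_kernel)[triton_grid](x, output, n_elements, BLOCK_SIZE=1024)
    return output
```

Wrapping launches this way is what lets the new `*_operation` functions participate in graph capture, subject to the README caveat that kernels relying on config pruning cannot yet use `wrap_triton()`.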
--- a/blksprs/layouting/sparsity_layout.py
+++ b/blksprs/layouting/sparsity_layout.py
@@ -3,7 +3,7 @@ import math
import torch
import triton
from torch import Tensor
- from torch._library.triton import wrap_triton
+ from torch._library.triton import wrap_triton, triton_op
from triton import language as tl

from blksprs.utils.blksprs_tensor import BlksprsTensor
@@ -29,27 +29,32 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
    validate_contiguous(x)
    validate_device(x)

-     output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
-                          dtype=torch.bool, device=x.device)
+     return build_sparsity_layout_operation(x, sparsity_block_size)

-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)

-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+ @triton_op("blksprs::build_sparsity_layout", mutates_args={})
+ def build_sparsity_layout_operation(x: Tensor, sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
+                              dtype=torch.bool, device=x.device)

-     # TODO wrap
-     (build_sparsity_layout_kernel[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)

-     return output
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output


@triton.autotune(
@@ -87,6 +92,7 @@ def build_sparsity_layout_kernel(x,
    tl.store(o + blk_o_idx, 1, mask=blk_o_msk)


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
                                   sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
    """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
@@ -114,33 +120,40 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso

    validate_contiguous(sparsity_layout_from, sparsity_lut)

-     o_b = sparsity_layout_from.size(0)
-     o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
-     o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
-
-     output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     # TODO wrap
-     (build_sparsity_layout_adaption_kernel[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size_from,
-       sparsity_block_size_to))
-
-     return output
+     return build_sparsity_layout_adaption_operation(x, sparsity_layout_from, sparsity_lut,
+                                                     sparsity_block_size_from, sparsity_block_size_to)
+
+
+ @triton_op("blksprs::build_sparsity_layout_adaption", mutates_args={})
+ def build_sparsity_layout_adaption_operation(x: Tensor, sparsity_layout_from: Tensor, sparsity_lut: Tensor,
+                                              sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
+     with torch.no_grad():
+         o_b = sparsity_layout_from.size(0)
+         o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
+         o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
+
+         output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size_from,
+           sparsity_block_size_to))
+
+         return output


@triton.autotune(
--- a/blksprs/ops/conversion.py
+++ b/blksprs/ops/conversion.py
@@ -52,33 +52,34 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
                          lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))


- @triton_op("blksprs::to_sparse", mutates_args={})
+ @triton_op("blksprs::to_sparse_forward", mutates_args={})
def to_sparse_forward(x: Tensor, _: Tensor,
                      sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                          dtype=x.dtype, device=x.device)
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)

-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)

-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-     (wrap_triton(to_sparse_kernel)[triton_grid]
-      (x, x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       output, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
+         (wrap_triton(to_sparse_kernel)[triton_grid]
+          (x, x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           output, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))

-     return output
+         return output


- def to_sparse_backward(ctx, grad_output):
+ def to_sparse_wrapper_backward(ctx, grad_output):
    sparsity_layout = ctx.saved_tensors[0]
    sparsity_block_size = ctx.sparsity_block_size

@@ -161,7 +162,7 @@ def to_sparse_setup_context(ctx, inputs, output):
    ctx.sparsity_block_size = sparsity_block_size


- to_sparse_forward.register_autograd(to_sparse_backward, setup_context=to_sparse_setup_context)
+ to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)


def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
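The `*_wrapper_backward` renames above feed into `register_autograd`, which attaches a backward formula to a `triton_op`. A minimal sketch continuing the hypothetical `demo::add_one` op from the earlier example (the setup/backward pair mirrors the structure used for `to_sparse_forward`):

```python
def add_one_setup_context(ctx, inputs, output):
    # Nothing needs saving: d(x + 1)/dx is the identity.
    pass


def add_one_backward(ctx, grad_output):
    # Return one gradient per tensor input of the op.
    return grad_output


add_one.register_autograd(add_one_backward, setup_context=add_one_setup_context)
```

Renaming the backward functions (e.g. `to_sparse_backward` to `to_sparse_wrapper_backward`) keeps them distinct from the kernels of the same name, matching the new unique op names such as `blksprs::to_sparse_forward`.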
@@ -207,38 +208,39 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
                        lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)


- @triton_op("blksprs::to_dense", mutates_args={})
+ @triton_op("blksprs::to_dense_forward", mutates_args={})
def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_reverse_lut: Tensor,
                     sparsity_block_size: int, fill_value: float) -> Tensor:
-     output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
-                               sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
-                         dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.shape
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_l_b, s_l_r, s_l_c = sparsity_layout.size()
-     s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(to_dense_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-       sparsity_reverse_lut,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
-
-     return output
-
-
- def to_dense_backward(ctx, grad_output):
+     with torch.no_grad():
+         output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
+                                   sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
+                             dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.shape
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+         s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(to_dense_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+           sparsity_reverse_lut,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output
+
+
+ def to_dense_wrapper_backward(ctx, grad_output):
    sparsity_layout = ctx.saved_tensors[0]
    sparsity_block_size = ctx.sparsity_block_size

@@ -316,7 +318,7 @@ def to_dense_setup_context(ctx, inputs, output):
    ctx.sparsity_block_size = sparsity_block_size


- to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)
+ to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)


@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -372,45 +374,45 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
                             n_sparse_blocks_to)), sparsity_layout_to


- @triton_op("blksprs::adapt_layout", mutates_args={})
+ @triton_op("blksprs::adapt_layout_forward", mutates_args={})
def adapt_layout_forward(x: Tensor,
                         sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
                         sparsity_block_size_from: int,
                         _: Tensor, sparsity_lut_to: Tensor,
                         sparsity_block_size_to: int,
                         n_sparse_blocks_to: int) -> Tensor:
-     output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
-                          dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
-     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
-     s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     # TODO wrap
-     (adapt_layout_kernel[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-       sparsity_reverse_lut_from,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-       sparsity_block_size_from,
-       sparsity_block_size_to))
-
-     return output
-
-
- def adapt_layout_backward(ctx, grad_output):
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
+         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(adapt_layout_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_reverse_lut_from,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+           sparsity_block_size_from,
+           sparsity_block_size_to))
+
+         return output
+
+
+ def adapt_layout_wrapper_backward(ctx, grad_output):
    x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
    sparsity_block_size_from = ctx.sparsity_block_size_from
    sparsity_block_size_to = ctx.sparsity_block_size_to
@@ -501,4 +503,4 @@ def adapt_layout_setup_context(ctx, inputs, output):
    ctx.sparsity_block_size_to = sparsity_block_size_to


- adapt_layout_forward.register_autograd(adapt_layout_backward, setup_context=adapt_layout_setup_context)
+ adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)