blksprs 2.0rc7__tar.gz → 2.0rc8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {blksprs-2.0rc7 → blksprs-2.0rc8}/PKG-INFO +1 -1
  2. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/__init__.py +1 -0
  3. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/layouting/distribution_layout.py +39 -26
  4. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/layouting/sparsity_layout.py +58 -45
  5. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/conversion.py +86 -84
  6. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/distribution.py +80 -78
  7. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/flow.py +64 -60
  8. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/matmul.py +50 -55
  9. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/misc/broadcast_ops.py +28 -27
  10. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/misc/row_wise.py +123 -125
  11. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/partitioning.py +12 -10
  12. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/repeat.py +6 -5
  13. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/softmax.py +55 -47
  14. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/transpose.py +8 -7
  15. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/utils/autotuning.py +10 -10
  16. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/utils/processing.py +0 -1
  17. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/utils/tools.py +8 -0
  18. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs.egg-info/PKG-INFO +1 -1
  19. {blksprs-2.0rc7 → blksprs-2.0rc8}/pyproject.toml +1 -1
  20. {blksprs-2.0rc7 → blksprs-2.0rc8}/README.md +0 -0
  21. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/utils/benchmarking.py +0 -0
  22. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/utils/blksprs_tensor.py +0 -0
  23. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/utils/validation.py +0 -0
  24. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs.egg-info/SOURCES.txt +0 -0
  25. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs.egg-info/dependency_links.txt +0 -0
  26. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs.egg-info/requires.txt +0 -0
  27. {blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs.egg-info/top_level.txt +0 -0
  28. {blksprs-2.0rc7 → blksprs-2.0rc8}/setup.cfg +0 -0
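
The recurring change in this release: Triton kernel launches that 2.0rc7 still marked with "# TODO wrap" are registered as named custom operators via triton_op and launched through wrap_triton, operator qualnames are aligned with their Python wrappers (e.g. blksprs::to_sparse becomes blksprs::to_sparse_forward), the backward callbacks are renamed to *_wrapper_backward, and kernel bookkeeping is moved under torch.no_grad(). A minimal sketch of the wrap pattern, not taken from blksprs: the demo::scale operator is invented for illustration, and it uses the public torch.library.triton_op / wrap_triton API (PyTorch >= 2.6) where the diff imports the equivalent private torch._library paths.

import torch
import triton
from torch.library import triton_op, wrap_triton
from triton import language as tl


@triton.jit
def scale_kernel(x_ptr, o_ptr, n_elements, factor, BLOCK_SIZE: tl.constexpr):
    # each program scales one BLOCK_SIZE-wide slice of the flattened input
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x * factor, mask=mask)


@triton_op("demo::scale", mutates_args={})
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = x.numel()
    triton_grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # launch through wrap_triton(...) instead of scale_kernel[triton_grid](...),
    # which makes the kernel call traceable by torch.compile and the dispatcher
    wrap_triton(scale_kernel)[triton_grid](x, output, n_elements, factor, BLOCK_SIZE=1024)
    return output
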
{blksprs-2.0rc7 → blksprs-2.0rc8}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.0rc7
+Version: 2.0rc8
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
{blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/__init__.py
@@ -1,3 +1,4 @@
+from blksprs.utils.tools import version
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 
{blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/layouting/distribution_layout.py
@@ -1,6 +1,10 @@
+import typing
+
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
@@ -10,6 +14,7 @@ from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                               dim: int, size_target: torch.Size,
                               sparsity_block_size: int) -> Tensor:
@@ -34,32 +39,40 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
 
     adjusted_dim = dim % 3
 
-    output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                         dtype=torch.bool, device=indices.device)
-
-    i_b, i_r, i_c = indices.size()
-    i_b_s, i_r_s, i_c_s = stride(indices)
-    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-    s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-
-    triton_grid = lambda meta: [i_b,
-                                triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
-
-    # TODO wrap
-    (build_distribution_layout_kernel[triton_grid]
-     (indices,
-      i_b, i_b_s, i_r_s, i_c_s,
-      sparsity_lut_i,
-      s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-      adjusted_dim,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_block_size))
-
-    return output
+    return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
+
+
+@triton_op("blksprs::build_distribution_layout", mutates_args={})
+def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
+                                        adjusted_dim: int, size_target: typing.List[int],
+                                        sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
+                             size_target[2] // sparsity_block_size,
+                             dtype=torch.bool, device=indices.device)
+
+        i_b, i_r, i_c = indices.size()
+        i_b_s, i_r_s, i_c_s = stride(indices)
+        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+        s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [i_b,
+                                    triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(build_distribution_layout_kernel)[triton_grid]
+         (indices,
+          i_b, i_b_s, i_r_s, i_c_s,
+          sparsity_lut_i,
+          s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+          adjusted_dim,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
 
 
 @triton.autotune(
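
Note on the @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16) decorator added to build_distribution_layout above (and to the other layout builders in this release): when such a function is called inside a CUDA autocast region, its floating-point tensor arguments are cast to float16 before the body runs and autocast is disabled within it, so the Triton kernels see one uniform dtype. A toy sketch of that behaviour, assuming a CUDA device; add_fp16 is an invented example, not blksprs API.

import torch


@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def add_fp16(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # under autocast, both arguments arrive here already cast to float16
    return a + b


a = torch.randn(4, device="cuda")  # float32
b = torch.randn(4, device="cuda")
with torch.autocast(device_type="cuda"):
    out = add_fp16(a, b)
print(out.dtype)  # torch.float16
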
{blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/layouting/sparsity_layout.py
@@ -3,7 +3,7 @@ import math
 import torch
 import triton
 from torch import Tensor
-from torch._library.triton import wrap_triton
+from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
@@ -29,27 +29,32 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
     validate_contiguous(x)
     validate_device(x)
 
-    output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
-                         dtype=torch.bool, device=x.device)
+    return build_sparsity_layout_operation(x, sparsity_block_size)
 
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
 
-    triton_grid = lambda meta: [x_b,
-                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+@triton_op("blksprs::build_sparsity_layout", mutates_args={})
+def build_sparsity_layout_operation(x: Tensor, sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
+                             dtype=torch.bool, device=x.device)
 
-    # TODO wrap
-    (build_sparsity_layout_kernel[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_block_size))
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
-    return output
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
 
 
 @triton.autotune(
@@ -87,6 +92,7 @@ def build_sparsity_layout_kernel(x,
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
                                    sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
     """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
@@ -114,33 +120,40 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
 
     validate_contiguous(sparsity_layout_from, sparsity_lut)
 
-    o_b = sparsity_layout_from.size(0)
-    o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
-    o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
-
-    output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
-
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    o_b_s, o_r_s, o_c_s = stride(output)
-
-    triton_grid = lambda meta: [x_b,
-                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-    # TODO wrap
-    (build_sparsity_layout_adaption_kernel[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_block_size_from,
-      sparsity_block_size_to))
-
-    return output
+    return build_sparsity_layout_adaption_operation(x, sparsity_layout_from, sparsity_lut,
+                                                    sparsity_block_size_from, sparsity_block_size_to)
+
+
+@triton_op("blksprs::build_sparsity_layout_adaption", mutates_args={})
+def build_sparsity_layout_adaption_operation(x: Tensor, sparsity_layout_from: Tensor, sparsity_lut: Tensor,
+                                             sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
+    with torch.no_grad():
+        o_b = sparsity_layout_from.size(0)
+        o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
+        o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
+
+        output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size_from,
+          sparsity_block_size_to))
+
+        return output
 
 
 @triton.autotune(
{blksprs-2.0rc7 → blksprs-2.0rc8}/blksprs/ops/conversion.py
@@ -52,33 +52,34 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
                                               lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
 
 
-@triton_op("blksprs::to_sparse", mutates_args={})
+@triton_op("blksprs::to_sparse_forward", mutates_args={})
 def to_sparse_forward(x: Tensor, _: Tensor,
                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                         dtype=x.dtype, device=x.device)
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-    (wrap_triton(to_sparse_kernel)[triton_grid]
-     (x, x_b, x_b_s, x_r_s, x_c_s,
-      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-      output, o_b_s, o_r_s, o_c_s,
-      sparsity_block_size))
+        (wrap_triton(to_sparse_kernel)[triton_grid]
+         (x, x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          output, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
 
-    return output
+        return output
 
 
-def to_sparse_backward(ctx, grad_output):
+def to_sparse_wrapper_backward(ctx, grad_output):
     sparsity_layout = ctx.saved_tensors[0]
     sparsity_block_size = ctx.sparsity_block_size
 
@@ -161,7 +162,7 @@ def to_sparse_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size
 
 
-to_sparse_forward.register_autograd(to_sparse_backward, setup_context=to_sparse_setup_context)
+to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)
 
 
 def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
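
The to_sparse_backward -> to_sparse_wrapper_backward rename above (mirrored below for to_dense and adapt_layout) only touches the Python callbacks handed to register_autograd; the registered operator schema is unaffected. For reference, a sketch of the same registration contract, continuing the invented demo::scale example from after the file list (not blksprs API):

def scale_setup_context(ctx, inputs, output):
    # stash what backward needs, as the blksprs *_setup_context functions
    # stash sparsity layouts and block sizes
    _, factor = inputs
    ctx.factor = factor


def scale_wrapper_backward(ctx, grad_output):
    # d/dx (x * factor) = factor; the non-tensor input receives None
    return grad_output * ctx.factor, None


scale.register_autograd(scale_wrapper_backward, setup_context=scale_setup_context)
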
@@ -207,38 +208,39 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
                                             lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
 
 
-@triton_op("blksprs::to_dense", mutates_args={})
+@triton_op("blksprs::to_dense_forward", mutates_args={})
 def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
                      sparsity_reverse_lut: Tensor,
                      sparsity_block_size: int, fill_value: float) -> Tensor:
-    output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
-                              sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
-                        dtype=x.dtype, device=x.device)
-
-    x_b, x_r, x_c = x.shape
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_l_b, s_l_r, s_l_c = sparsity_layout.size()
-    s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    (wrap_triton(to_dense_kernel)[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-      sparsity_reverse_lut,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_block_size))
-
-    return output
-
-
-def to_dense_backward(ctx, grad_output):
+    with torch.no_grad():
+        output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
+                                  sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
+                            dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.shape
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+        s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(to_dense_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+          sparsity_reverse_lut,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+def to_dense_wrapper_backward(ctx, grad_output):
     sparsity_layout = ctx.saved_tensors[0]
     sparsity_block_size = ctx.sparsity_block_size
 
@@ -316,7 +318,7 @@ def to_dense_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size
 
 
-to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)
+to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)
 
 
 @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -372,45 +374,45 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
                                          n_sparse_blocks_to)), sparsity_layout_to
 
 
-@triton_op("blksprs::adapt_layout", mutates_args={})
+@triton_op("blksprs::adapt_layout_forward", mutates_args={})
 def adapt_layout_forward(x: Tensor,
                          sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
                          sparsity_block_size_from: int,
                          _: Tensor, sparsity_lut_to: Tensor,
                          sparsity_block_size_to: int,
                          n_sparse_blocks_to: int) -> Tensor:
-    output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
-                         dtype=x.dtype, device=x.device)
-
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
-    s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-    s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
-    s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    # TODO wrap
-    (adapt_layout_kernel[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-      sparsity_reverse_lut_from,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-      sparsity_block_size_from,
-      sparsity_block_size_to))
-
-    return output
-
-
-def adapt_layout_backward(ctx, grad_output):
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+        s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
+        s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(adapt_layout_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_reverse_lut_from,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+          sparsity_block_size_from,
+          sparsity_block_size_to))
+
+        return output
+
+
+def adapt_layout_wrapper_backward(ctx, grad_output):
     x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
     sparsity_block_size_from = ctx.sparsity_block_size_from
     sparsity_block_size_to = ctx.sparsity_block_size_to
@@ -501,4 +503,4 @@ def adapt_layout_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size_to = sparsity_block_size_to
 
 
-adapt_layout_forward.register_autograd(adapt_layout_backward, setup_context=adapt_layout_setup_context)
+adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)
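
Closing usage note: operators registered via triton_op are also exposed through the dispatcher, so the qualname changes in this release are visible wherever ops are addressed by name; traced or compiled graphs will reference, e.g., torch.ops.blksprs.to_sparse_forward rather than torch.ops.blksprs.to_sparse. With the invented demo::scale sketch from earlier executed on a CUDA machine, both call paths resolve to the same kernel:

x = torch.randn(8, device="cuda")
# the Python wrapper and the dispatcher entry should agree
assert torch.allclose(scale(x, 2.0), torch.ops.demo.scale(x, 2.0))
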