blksprs-2.0rc7-py3-none-any.whl → blksprs-2.0rc8-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
blksprs/__init__.py CHANGED
@@ -1,3 +1,4 @@
+ from blksprs.utils.tools import version
  from blksprs.utils.blksprs_tensor import BlksprsTensor
  
  
@@ -1,6 +1,10 @@
+ import typing
+
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl
  
  from blksprs.utils.blksprs_tensor import BlksprsTensor
@@ -10,6 +14,7 @@ from blksprs.utils.validation import validate_dimensions, validate_device, \
      validate_contiguous
  
  
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                                dim: int, size_target: torch.Size,
                                sparsity_block_size: int) -> Tensor:
@@ -34,32 +39,40 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
  
      adjusted_dim = dim % 3
  
-     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                          dtype=torch.bool, device=indices.device)
- 
-     i_b, i_r, i_c = indices.size()
-     i_b_s, i_r_s, i_c_s = stride(indices)
-     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-     s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
- 
-     triton_grid = lambda meta: [i_b,
-                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
- 
-     # TODO wrap
-     (build_distribution_layout_kernel[triton_grid]
-      (indices,
-       i_b, i_b_s, i_r_s, i_c_s,
-       sparsity_lut_i,
-       s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-       adjusted_dim,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
- 
-     return output
+     return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
+ 
+ 
+ @triton_op("blksprs::build_distribution_layout", mutates_args={})
+ def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
+                                         adjusted_dim: int, size_target: typing.List[int],
+                                         sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
+                              size_target[2] // sparsity_block_size,
+                              dtype=torch.bool, device=indices.device)
+ 
+         i_b, i_r, i_c = indices.size()
+         i_b_s, i_r_s, i_c_s = stride(indices)
+         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+         s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+ 
+         triton_grid = lambda meta: [i_b,
+                                     triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+ 
+         (wrap_triton(build_distribution_layout_kernel)[triton_grid]
+          (indices,
+           i_b, i_b_s, i_r_s, i_c_s,
+           sparsity_lut_i,
+           s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+           adjusted_dim,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+ 
+         return output
  
  
  @triton.autotune(
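The recurring change in this hunk and the ones below is to replace direct `kernel[triton_grid](...)` launches with functions registered through `@triton_op` that launch the kernel via `wrap_triton`, so the Triton call stays traceable by `torch.compile` instead of being an opaque Python call. A minimal, self-contained sketch of that pattern follows, with hypothetical names (`example::add_one`, `add_one_kernel`) that are not part of blksprs; it assumes a CUDA device and a PyTorch build that exposes the public `torch.library.triton_op` / `wrap_triton` (the diff itself imports the private `torch._library` paths).

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton  # public aliases of the private imports used above


@triton.jit
def add_one_kernel(x_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(o_ptr + offsets, tl.load(x_ptr + offsets, mask=mask) + 1.0, mask=mask)


@triton_op("example::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
    # wrap_triton makes the kernel launch visible to torch.compile instead of an opaque call
    wrap_triton(add_one_kernel)[grid](x, output, x.numel(), BLOCK_SIZE=1024)
    return output
```

Calling `add_one(torch.randn(4096, device="cuda"))` then behaves like any other PyTorch operator, including inside compiled functions.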
@@ -3,7 +3,7 @@ import math
  import torch
  import triton
  from torch import Tensor
- from torch._library.triton import wrap_triton
+ from torch._library.triton import wrap_triton, triton_op
  from triton import language as tl
  
  from blksprs.utils.blksprs_tensor import BlksprsTensor
@@ -29,27 +29,32 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
      validate_contiguous(x)
      validate_device(x)
  
-     output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
-                          dtype=torch.bool, device=x.device)
+     return build_sparsity_layout_operation(x, sparsity_block_size)
  
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
  
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+ @triton_op("blksprs::build_sparsity_layout", mutates_args={})
+ def build_sparsity_layout_operation(x: Tensor, sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
+                              dtype=torch.bool, device=x.device)
  
-     # TODO wrap
-     (build_sparsity_layout_kernel[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
  
-     return output
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+ 
+         (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+ 
+         return output
  
  
  @triton.autotune(
@@ -87,6 +92,7 @@ def build_sparsity_layout_kernel(x,
      tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
  
  
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
                                     sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
      """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
@@ -114,33 +120,40 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
  
      validate_contiguous(sparsity_layout_from, sparsity_lut)
  
-     o_b = sparsity_layout_from.size(0)
-     o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
-     o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
- 
-     output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
- 
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     o_b_s, o_r_s, o_c_s = stride(output)
- 
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
- 
-     # TODO wrap
-     (build_sparsity_layout_adaption_kernel[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size_from,
-       sparsity_block_size_to))
- 
-     return output
+     return build_sparsity_layout_adaption_operation(x, sparsity_layout_from, sparsity_lut,
+                                                     sparsity_block_size_from, sparsity_block_size_to)
+ 
+ 
+ @triton_op("blksprs::build_sparsity_layout_adaption", mutates_args={})
+ def build_sparsity_layout_adaption_operation(x: Tensor, sparsity_layout_from: Tensor, sparsity_lut: Tensor,
+                                              sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
+     with torch.no_grad():
+         o_b = sparsity_layout_from.size(0)
+         o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
+         o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
+ 
+         output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
+ 
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         o_b_s, o_r_s, o_c_s = stride(output)
+ 
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+ 
+         (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size_from,
+           sparsity_block_size_to))
+ 
+         return output
  
  
  @triton.autotune(
blksprs/ops/conversion.py CHANGED
@@ -52,33 +52,34 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
                                           lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
  
  
- @triton_op("blksprs::to_sparse", mutates_args={})
+ @triton_op("blksprs::to_sparse_forward", mutates_args={})
  def to_sparse_forward(x: Tensor, _: Tensor,
                        sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                          dtype=x.dtype, device=x.device)
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
  
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
  
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
  
-     (wrap_triton(to_sparse_kernel)[triton_grid]
-      (x, x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       output, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
+         (wrap_triton(to_sparse_kernel)[triton_grid]
+          (x, x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           output, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
  
-     return output
+         return output
  
  
- def to_sparse_backward(ctx, grad_output):
+ def to_sparse_wrapper_backward(ctx, grad_output):
      sparsity_layout = ctx.saved_tensors[0]
      sparsity_block_size = ctx.sparsity_block_size
  
@@ -161,7 +162,7 @@ def to_sparse_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size
  
  
- to_sparse_forward.register_autograd(to_sparse_backward, setup_context=to_sparse_setup_context)
+ to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)
  
  
  def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
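The `*_backward` callables renamed to `*_wrapper_backward` throughout this file are plain functions attached to the custom-op forwards via `register_autograd`, which is how ops defined with `triton_op`/`custom_op` gain a backward pass without subclassing `torch.autograd.Function`. A self-contained sketch of that registration pattern using a hypothetical op (`example::square`, not the blksprs API):

```python
import torch
from torch.library import custom_op


@custom_op("example::square", mutates_args=())
def square(x: torch.Tensor) -> torch.Tensor:
    return x * x


@square.register_fake
def _(x):
    # shape/dtype propagation so the op also works under torch.compile and meta tensors
    return torch.empty_like(x)


def square_setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)


def square_wrapper_backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return 2 * x * grad_output


square.register_autograd(square_wrapper_backward, setup_context=square_setup_context)
```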
@@ -207,38 +208,39 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
                              lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
  
  
- @triton_op("blksprs::to_dense", mutates_args={})
+ @triton_op("blksprs::to_dense_forward", mutates_args={})
  def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
                       sparsity_reverse_lut: Tensor,
                       sparsity_block_size: int, fill_value: float) -> Tensor:
-     output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
-                               sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
-                         dtype=x.dtype, device=x.device)
- 
-     x_b, x_r, x_c = x.shape
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_l_b, s_l_r, s_l_c = sparsity_layout.size()
-     s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
- 
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
- 
-     (wrap_triton(to_dense_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-       sparsity_reverse_lut,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
- 
-     return output
- 
- 
- def to_dense_backward(ctx, grad_output):
+     with torch.no_grad():
+         output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
+                                   sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
+                             dtype=x.dtype, device=x.device)
+ 
+         x_b, x_r, x_c = x.shape
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+         s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+ 
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+ 
+         (wrap_triton(to_dense_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+           sparsity_reverse_lut,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+ 
+         return output
+ 
+ 
+ def to_dense_wrapper_backward(ctx, grad_output):
      sparsity_layout = ctx.saved_tensors[0]
      sparsity_block_size = ctx.sparsity_block_size
  
@@ -316,7 +318,7 @@ def to_dense_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size
  
  
- to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)
+ to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)
  
  
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -372,45 +374,45 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
                                       n_sparse_blocks_to)), sparsity_layout_to
  
  
- @triton_op("blksprs::adapt_layout", mutates_args={})
+ @triton_op("blksprs::adapt_layout_forward", mutates_args={})
  def adapt_layout_forward(x: Tensor,
                           sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
                           sparsity_block_size_from: int,
                           _: Tensor, sparsity_lut_to: Tensor,
                           sparsity_block_size_to: int,
                           n_sparse_blocks_to: int) -> Tensor:
-     output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
-                          dtype=x.dtype, device=x.device)
- 
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
-     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
-     s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
- 
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
- 
-     # TODO wrap
-     (adapt_layout_kernel[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-       sparsity_reverse_lut_from,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-       sparsity_block_size_from,
-       sparsity_block_size_to))
- 
-     return output
- 
- 
- def adapt_layout_backward(ctx, grad_output):
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
+                              dtype=x.dtype, device=x.device)
+ 
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
+         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
+ 
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+ 
+         (wrap_triton(adapt_layout_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_reverse_lut_from,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+           sparsity_block_size_from,
+           sparsity_block_size_to))
+ 
+         return output
+ 
+ 
+ def adapt_layout_wrapper_backward(ctx, grad_output):
      x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
      sparsity_block_size_from = ctx.sparsity_block_size_from
      sparsity_block_size_to = ctx.sparsity_block_size_to
@@ -501,4 +503,4 @@ def adapt_layout_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size_to = sparsity_block_size_to
  
  
- adapt_layout_forward.register_autograd(adapt_layout_backward, setup_context=adapt_layout_setup_context)
+ adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)
@@ -51,44 +51,45 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
                                sparsity_block_size))
  
  
- @triton_op("blksprs::gather", mutates_args={})
+ @triton_op("blksprs::gather_forward", mutates_args={})
  def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                     dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
                     sparsity_block_size: int) -> Tensor:
-     output = torch.zeros_like(i, dtype=x.dtype)
- 
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
-     i_b, i_r, i_c = i.size()
-     i_b_s, i_r_s, i_c_s = stride(i)
-     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-     s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
- 
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
- 
-     (wrap_triton(gather_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-       sparsity_reverse_lut_x,
-       dim,
-       i,
-       i_b, i_b_s, i_r_s, i_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-       sparsity_block_size))
- 
-     return output
- 
- 
- def gather_backward(ctx, grad_output):
+     with torch.no_grad():
+         output = torch.zeros_like(i, dtype=x.dtype)
+ 
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+         i_b, i_r, i_c = i.size()
+         i_b_s, i_r_s, i_c_s = stride(i)
+         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+         s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+ 
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+ 
+         (wrap_triton(gather_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_reverse_lut_x,
+           dim,
+           i,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+           sparsity_block_size))
+ 
+         return output
+ 
+ 
+ def gather_wrapper_backward(ctx, grad_output):
      sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
      dim = ctx.dim
      sparsity_block_size = ctx.sparsity_block_size
@@ -221,7 +222,7 @@ def gather_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size
  
  
- gather_forward.register_autograd(gather_backward, setup_context=gather_setup_context)
+ gather_forward.register_autograd(gather_wrapper_backward, setup_context=gather_setup_context)
  
  
  def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
@@ -288,52 +289,53 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                                        reduce_op))
  
  
- @triton_op("blksprs::scatter_reduce", mutates_args={})
+ @triton_op("blksprs::scatter_reduce_forward", mutates_args={})
  def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
                             dim: int, i: Tensor,
                             sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
                             sparsity_block_size: int, n_sparse_blocks: int,
                             reduce_op: str) -> Tensor:
-     output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                          dtype=x.dtype, device=x.device)
- 
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
-     s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
-     i_b, i_r, i_c = i.size()
-     i_b_s, i_r_s, i_c_s = stride(i)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
- 
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
- 
-     reduce_op_ind = 0
-     if reduce_op == "sum":
-         reduce_op_ind = 1
- 
-     (wrap_triton(scatter_reduce_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-       dim,
-       i,
-       i_b, i_b_s, i_r_s, i_c_s,
-       output,
-       o_b, o_b_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-       sparsity_reverse_lut_o,
-       reduce_op_ind,
-       sparsity_block_size))
- 
-     return output
- 
- 
- def scatter_reduce_backward(ctx, grad_output):
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+ 
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
+         i_b, i_r, i_c = i.size()
+         i_b_s, i_r_s, i_c_s = stride(i)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+ 
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+ 
+         reduce_op_ind = 0
+         if reduce_op == "sum":
+             reduce_op_ind = 1
+ 
+         (wrap_triton(scatter_reduce_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           dim,
+           i,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+           sparsity_reverse_lut_o,
+           reduce_op_ind,
+           sparsity_block_size))
+ 
+         return output
+ 
+ 
+ def scatter_reduce_wrapper_backward(ctx, grad_output):
      sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
      dim = ctx.dim
      sparsity_block_size = ctx.sparsity_block_size
@@ -477,4 +479,4 @@ def scatter_reduce_setup_context(ctx, inputs, output):
      ctx.reduce_op = reduce_op
  
  
- scatter_reduce_forward.register_autograd(scatter_reduce_backward, setup_context=scatter_reduce_setup_context)
+ scatter_reduce_forward.register_autograd(scatter_reduce_wrapper_backward, setup_context=scatter_reduce_setup_context)
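Taken together, registering each forward through `triton_op` and attaching its backward via `register_autograd` keeps the wrapped kernels both differentiable and traceable. A short hypothetical usage sketch, reusing the `add_one` op from the first sketch above (assumes a CUDA device and a recent PyTorch/Triton; not part of the blksprs API):

```python
import torch


@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    # add_one is the hypothetical triton_op-registered op from the earlier sketch;
    # because it is a registered custom op, torch.compile can trace straight through it
    return add_one(x) * 2.0


print(fn(torch.randn(4096, device="cuda"))[:4])
```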