blksprs 2.0rc7__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -51,44 +51,45 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
51
51
  sparsity_block_size))
52
52
 
53
53
 
54
- @triton_op("blksprs::gather", mutates_args={})
54
+ @triton_op("blksprs::gather_forward", mutates_args={})
55
55
  def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
56
56
  dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
57
57
  sparsity_block_size: int) -> Tensor:
58
- output = torch.zeros_like(i, dtype=x.dtype)
59
-
60
- x_b, x_r, x_c = x.size()
61
- x_b_s, x_r_s, x_c_s = stride(x)
62
- s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
63
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
64
- i_b, i_r, i_c = i.size()
65
- i_b_s, i_r_s, i_c_s = stride(i)
66
- s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
67
- s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
68
- o_b, o_r, o_c = output.size()
69
- o_b_s, o_r_s, o_c_s = stride(output)
70
-
71
- triton_grid = lambda meta: [o_b,
72
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
73
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
74
-
75
- (wrap_triton(gather_kernel)[triton_grid]
76
- (x,
77
- x_b, x_b_s, x_r_s, x_c_s,
78
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
79
- sparsity_reverse_lut_x,
80
- dim,
81
- i,
82
- i_b, i_b_s, i_r_s, i_c_s,
83
- output,
84
- o_b, o_b_s, o_r_s, o_c_s,
85
- sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
86
- sparsity_block_size))
87
-
88
- return output
89
-
90
-
91
- def gather_backward(ctx, grad_output):
58
+ with torch.no_grad():
59
+ output = torch.zeros_like(i, dtype=x.dtype)
60
+
61
+ x_b, x_r, x_c = x.size()
62
+ x_b_s, x_r_s, x_c_s = stride(x)
63
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
64
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
65
+ i_b, i_r, i_c = i.size()
66
+ i_b_s, i_r_s, i_c_s = stride(i)
67
+ s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
68
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
69
+ o_b, o_r, o_c = output.size()
70
+ o_b_s, o_r_s, o_c_s = stride(output)
71
+
72
+ triton_grid = lambda meta: [o_b,
73
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
74
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
75
+
76
+ (wrap_triton(gather_kernel)[triton_grid]
77
+ (x,
78
+ x_b, x_b_s, x_r_s, x_c_s,
79
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
80
+ sparsity_reverse_lut_x,
81
+ dim,
82
+ i,
83
+ i_b, i_b_s, i_r_s, i_c_s,
84
+ output,
85
+ o_b, o_b_s, o_r_s, o_c_s,
86
+ sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
87
+ sparsity_block_size))
88
+
89
+ return output
90
+
91
+
92
+ def gather_wrapper_backward(ctx, grad_output):
92
93
  sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
93
94
  dim = ctx.dim
94
95
  sparsity_block_size = ctx.sparsity_block_size
@@ -221,7 +222,7 @@ def gather_setup_context(ctx, inputs, output):
221
222
  ctx.sparsity_block_size = sparsity_block_size
222
223
 
223
224
 
224
- gather_forward.register_autograd(gather_backward, setup_context=gather_setup_context)
225
+ gather_forward.register_autograd(gather_wrapper_backward, setup_context=gather_setup_context)
225
226
 
226
227
 
227
228
  def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
@@ -288,52 +289,53 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
288
289
  reduce_op))
289
290
 
290
291
 
291
- @triton_op("blksprs::scatter_reduce", mutates_args={})
292
+ @triton_op("blksprs::scatter_reduce_forward", mutates_args={})
292
293
  def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
293
294
  dim: int, i: Tensor,
294
295
  sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
295
296
  sparsity_block_size: int, n_sparse_blocks: int,
296
297
  reduce_op: str) -> Tensor:
297
- output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
298
- dtype=x.dtype, device=x.device)
299
-
300
- x_b, x_r, x_c = x.size()
301
- x_b_s, x_r_s, x_c_s = stride(x)
302
- s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
303
- s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
304
- i_b, i_r, i_c = i.size()
305
- i_b_s, i_r_s, i_c_s = stride(i)
306
- o_b, o_r, o_c = output.size()
307
- o_b_s, o_r_s, o_c_s = stride(output)
308
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
309
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
310
-
311
- triton_grid = lambda meta: [x_b,
312
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
313
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
314
-
315
- reduce_op_ind = 0
316
- if reduce_op == "sum":
317
- reduce_op_ind = 1
318
-
319
- (wrap_triton(scatter_reduce_kernel)[triton_grid]
320
- (x,
321
- x_b, x_b_s, x_r_s, x_c_s,
322
- sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
323
- dim,
324
- i,
325
- i_b, i_b_s, i_r_s, i_c_s,
326
- output,
327
- o_b, o_b_s,
328
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
329
- sparsity_reverse_lut_o,
330
- reduce_op_ind,
331
- sparsity_block_size))
332
-
333
- return output
334
-
335
-
336
- def scatter_reduce_backward(ctx, grad_output):
298
+ with torch.no_grad():
299
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
300
+ dtype=x.dtype, device=x.device)
301
+
302
+ x_b, x_r, x_c = x.size()
303
+ x_b_s, x_r_s, x_c_s = stride(x)
304
+ s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
305
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
306
+ i_b, i_r, i_c = i.size()
307
+ i_b_s, i_r_s, i_c_s = stride(i)
308
+ o_b, o_r, o_c = output.size()
309
+ o_b_s, o_r_s, o_c_s = stride(output)
310
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
311
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
312
+
313
+ triton_grid = lambda meta: [x_b,
314
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
315
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
316
+
317
+ reduce_op_ind = 0
318
+ if reduce_op == "sum":
319
+ reduce_op_ind = 1
320
+
321
+ (wrap_triton(scatter_reduce_kernel)[triton_grid]
322
+ (x,
323
+ x_b, x_b_s, x_r_s, x_c_s,
324
+ sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
325
+ dim,
326
+ i,
327
+ i_b, i_b_s, i_r_s, i_c_s,
328
+ output,
329
+ o_b, o_b_s,
330
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
331
+ sparsity_reverse_lut_o,
332
+ reduce_op_ind,
333
+ sparsity_block_size))
334
+
335
+ return output
336
+
337
+
338
+ def scatter_reduce_wrapper_backward(ctx, grad_output):
337
339
  sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
338
340
  dim = ctx.dim
339
341
  sparsity_block_size = ctx.sparsity_block_size
@@ -477,4 +479,4 @@ def scatter_reduce_setup_context(ctx, inputs, output):
477
479
  ctx.reduce_op = reduce_op
478
480
 
479
481
 
480
- scatter_reduce_forward.register_autograd(scatter_reduce_backward, setup_context=scatter_reduce_setup_context)
482
+ scatter_reduce_forward.register_autograd(scatter_reduce_wrapper_backward, setup_context=scatter_reduce_setup_context)
blksprs/ops/flow.py CHANGED
@@ -9,39 +9,41 @@ from blksprs.utils.tools import stride
9
9
  from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
10
 
11
11
 
12
- @triton_op("blksprs::flow_pull", mutates_args={})
12
+ @triton_op("blksprs::flow_pull_forward", mutates_args={})
13
13
  def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
14
14
  sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
15
15
  sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
16
- output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
17
- dtype=x.dtype, device=x.device)
18
-
19
- x_b, x_r, x_c = x.size()
20
- x_b_s, x_r_s, x_c_s = stride(x)
21
- o_b, o_r, o_c = output.size()
22
- o_b_s, o_r_s, o_c_s = stride(output)
23
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
24
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
25
- s_lut_r, s_lut_c = sparsity_lut.size()
26
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
27
-
28
- triton_grid = lambda meta: [o_b,
29
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
30
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
31
-
32
- (wrap_triton(flow_pull_kernel)[triton_grid]
33
- (x,
34
- x_b, x_b_s, x_r_s, x_c_s,
35
- output,
36
- o_b, o_b_s, o_r_s, o_c_s,
37
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
38
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
39
- sparsity_reverse_lut,
40
- sparsity_block_size))
41
-
42
- return output
43
-
44
-
16
+ with torch.no_grad():
17
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
18
+ dtype=x.dtype, device=x.device)
19
+
20
+ x_b, x_r, x_c = x.size()
21
+ x_b_s, x_r_s, x_c_s = stride(x)
22
+ o_b, o_r, o_c = output.size()
23
+ o_b_s, o_r_s, o_c_s = stride(output)
24
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
25
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
26
+ s_lut_r, s_lut_c = sparsity_lut.size()
27
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
28
+
29
+ triton_grid = lambda meta: [o_b,
30
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
31
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
32
+
33
+ (wrap_triton(flow_pull_kernel)[triton_grid]
34
+ (x,
35
+ x_b, x_b_s, x_r_s, x_c_s,
36
+ output,
37
+ o_b, o_b_s, o_r_s, o_c_s,
38
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
39
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
40
+ sparsity_reverse_lut,
41
+ sparsity_block_size))
42
+
43
+ return output
44
+
45
+
46
+ # noinspection PyUnusedLocal
45
47
  @triton.autotune(
46
48
  configs=get_autotune_configs(),
47
49
  key=["sparsity_block_size"],
@@ -76,7 +78,7 @@ def flow_pull_kernel(x,
76
78
  spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
77
79
  spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
78
80
 
79
- # Get reverse sparsity index
81
+ # Load reverse sparsity index
80
82
  rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
81
83
  spa_row * s_l_o_r_s +
82
84
  spa_col * s_l_o_c_s)
@@ -99,38 +101,40 @@ def flow_pull_kernel(x,
99
101
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
100
102
 
101
103
 
102
- @triton_op("blksprs::flow_push", mutates_args={})
104
+ @triton_op("blksprs::flow_push_forward", mutates_args={})
103
105
  def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
104
106
  sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
105
- output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
106
- dtype=x.dtype, device=x.device)
107
-
108
- x_b, x_r, x_c = x.size()
109
- x_b_s, x_r_s, x_c_s = stride(x)
110
- s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
111
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
112
- s_lut_r, s_lut_c = sparsity_lut.size()
113
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
114
- o_b, o_r, o_c = output.size()
115
- o_b_s, o_r_s, o_c_s = stride(output)
116
-
117
- triton_grid = lambda meta: [x_b,
118
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
119
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
120
-
121
- (wrap_triton(flow_push_kernel)[triton_grid]
122
- (x,
123
- x_b, x_b_s, x_r_s, x_c_s,
124
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
125
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
126
- sparsity_reverse_lut,
127
- output,
128
- o_b, o_b_s, o_r_s, o_c_s,
129
- sparsity_block_size))
130
-
131
- return output
132
-
133
-
107
+ with torch.no_grad():
108
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
109
+ dtype=x.dtype, device=x.device)
110
+
111
+ x_b, x_r, x_c = x.size()
112
+ x_b_s, x_r_s, x_c_s = stride(x)
113
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
114
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
115
+ s_lut_r, s_lut_c = sparsity_lut.size()
116
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
117
+ o_b, o_r, o_c = output.size()
118
+ o_b_s, o_r_s, o_c_s = stride(output)
119
+
120
+ triton_grid = lambda meta: [x_b,
121
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
122
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
123
+
124
+ (wrap_triton(flow_push_kernel)[triton_grid]
125
+ (x,
126
+ x_b, x_b_s, x_r_s, x_c_s,
127
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
128
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
129
+ sparsity_reverse_lut,
130
+ output,
131
+ o_b, o_b_s, o_r_s, o_c_s,
132
+ sparsity_block_size))
133
+
134
+ return output
135
+
136
+
137
+ # noinspection PyUnusedLocal
134
138
  @triton.autotune(
135
139
  configs=get_autotune_configs(),
136
140
  key=["sparsity_block_size"],
blksprs/ops/matmul.py CHANGED
@@ -55,53 +55,54 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
55
55
  sparsity_block_size, lut["n_sparse_blocks"]))
56
56
 
57
57
 
58
- @triton_op("blksprs::matmul", mutates_args={})
58
+ @triton_op("blksprs::matmul_forward", mutates_args={})
59
59
  def matmul_forward(x: Tensor, y: Tensor,
60
60
  sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
61
61
  sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
62
62
  _: Tensor, sparsity_lut_o: Tensor,
63
63
  sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
64
- output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
65
- dtype=x.dtype, device=x.device)
66
-
67
- x_b, x_r, x_c = x.size()
68
- x_b_s, x_r_s, x_c_s = stride(x)
69
- s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
70
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
71
- y_b, y_r, y_c = y.size()
72
- y_b_s, y_r_s, y_c_s = stride(y)
73
- s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
74
- s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
75
- o_b, o_r, o_c = output.size()
76
- o_b_s, o_r_s, o_c_s = stride(output)
77
- s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
78
- s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
79
-
80
- triton_grid = lambda meta: [o_b,
81
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
82
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
83
-
84
- (wrap_triton(matmul_kernel)[triton_grid]
85
- (x,
86
- x_b, x_b_s, x_r_s, x_c_s,
87
- s_l_x_b, s_l_x_b_s, s_l_x_r_s,
88
- s_l_x_c, s_l_x_c_s,
89
- sparsity_reverse_lut_x,
90
- y,
91
- y_b, y_b_s, y_r_s, y_c_s,
92
- s_l_y_b, s_l_y_b_s, s_l_y_r_s,
93
- s_l_y_c_s,
94
- sparsity_reverse_lut_y,
95
- output,
96
- o_b, o_b_s, o_r_s, o_c_s,
97
- sparsity_lut_o,
98
- s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
99
- sparsity_block_size))
100
-
101
- return output
102
-
103
-
104
- def matmul_backward(ctx, grad_output):
64
+ with torch.no_grad():
65
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
66
+ dtype=x.dtype, device=x.device)
67
+
68
+ x_b, x_r, x_c = x.size()
69
+ x_b_s, x_r_s, x_c_s = stride(x)
70
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
71
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
72
+ y_b, y_r, y_c = y.size()
73
+ y_b_s, y_r_s, y_c_s = stride(y)
74
+ s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
75
+ s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
76
+ o_b, o_r, o_c = output.size()
77
+ o_b_s, o_r_s, o_c_s = stride(output)
78
+ s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
79
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
80
+
81
+ triton_grid = lambda meta: [o_b,
82
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
83
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
84
+
85
+ (wrap_triton(matmul_kernel)[triton_grid]
86
+ (x,
87
+ x_b, x_b_s, x_r_s, x_c_s,
88
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s,
89
+ s_l_x_c, s_l_x_c_s,
90
+ sparsity_reverse_lut_x,
91
+ y,
92
+ y_b, y_b_s, y_r_s, y_c_s,
93
+ s_l_y_b, s_l_y_b_s, s_l_y_r_s,
94
+ s_l_y_c_s,
95
+ sparsity_reverse_lut_y,
96
+ output,
97
+ o_b, o_b_s, o_r_s, o_c_s,
98
+ sparsity_lut_o,
99
+ s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
100
+ sparsity_block_size))
101
+
102
+ return output
103
+
104
+
105
+ def matmul_wrapper_backward(ctx, grad_output):
105
106
  x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
106
107
  sparsity_block_size = ctx.sparsity_block_size
107
108
 
@@ -187,20 +188,16 @@ def matmul_kernel(x,
187
188
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
188
189
  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
189
190
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
190
- blk_x_msk = ((blk_x_idx >= 0 and
191
- blk_x_idx < x_b * x_b_s) and
192
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
193
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
191
+ blk_x_msk = (blk_x_idx >= 0 and
192
+ blk_x_idx < x_b * x_b_s)
194
193
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
195
194
 
196
195
  blk_y_idx = ((rev_idx_spa_y * y_b_s) +
197
196
  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
198
197
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
199
198
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
200
- blk_y_msk = ((blk_y_idx >= 0 and
201
- blk_y_idx < y_b * y_b_s) and
202
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
203
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
199
+ blk_y_msk = (blk_y_idx >= 0 and
200
+ blk_y_idx < y_b * y_b_s)
204
201
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
205
202
 
206
203
  # Perform matrix multiplication
@@ -213,10 +210,8 @@ def matmul_kernel(x,
213
210
  blk_o_idx = ((pid_blk * o_b_s) +
214
211
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
215
212
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
216
- blk_o_msk = ((blk_o_idx >= 0 and
217
- blk_o_idx < o_b * o_b_s) and
218
- (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
219
- tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
213
+ blk_o_msk = (blk_o_idx >= 0 and
214
+ blk_o_idx < o_b * o_b_s)
220
215
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
221
216
 
222
217
 
@@ -262,4 +257,4 @@ def matmul_setup_context(ctx, inputs, output):
262
257
  ctx.sparsity_block_size = sparsity_block_size
263
258
 
264
259
 
265
- matmul_forward.register_autograd(matmul_backward, setup_context=matmul_setup_context)
260
+ matmul_forward.register_autograd(matmul_wrapper_backward, setup_context=matmul_setup_context)
@@ -55,36 +55,37 @@ def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
55
55
  return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size)
56
56
 
57
57
 
58
- @triton_op("blksprs::broadcast_add", mutates_args={})
58
+ @triton_op("blksprs::broadcast_add_forward", mutates_args={})
59
59
  def broadcast_add_forward(x: Tensor, y: Tensor,
60
60
  sparsity_lut_o: Tensor,
61
61
  sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
62
- output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
63
-
64
- x_b, x_c = x.size()
65
- x_b_s, x_c_s = stride(x)
66
- y_b, y_c = y.size()
67
- y_b_s, y_c_s = stride(y)
68
- o_b, o_r, o_c = output.size()
69
- o_b_s, o_r_s, o_c_s = stride(output)
70
- s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
71
- s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
72
-
73
- triton_grid = lambda meta: [o_b,
74
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
75
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
76
-
77
- (wrap_triton(broadcast_add_kernel)[triton_grid]
78
- (x,
79
- x_b, x_b_s, x_c_s,
80
- y,
81
- y_b, y_b_s, y_c_s,
82
- output,
83
- o_b, o_b_s, o_r_s, o_c_s,
84
- sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
85
- sparsity_block_size))
86
-
87
- return BlksprsTensor(output)
62
+ with torch.no_grad():
63
+ output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
64
+
65
+ x_b, x_c = x.size()
66
+ x_b_s, x_c_s = stride(x)
67
+ y_b, y_c = y.size()
68
+ y_b_s, y_c_s = stride(y)
69
+ o_b, o_r, o_c = output.size()
70
+ o_b_s, o_r_s, o_c_s = stride(output)
71
+ s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
72
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
73
+
74
+ triton_grid = lambda meta: [o_b,
75
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
76
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
77
+
78
+ (wrap_triton(broadcast_add_kernel)[triton_grid]
79
+ (x,
80
+ x_b, x_b_s, x_c_s,
81
+ y,
82
+ y_b, y_b_s, y_c_s,
83
+ output,
84
+ o_b, o_b_s, o_r_s, o_c_s,
85
+ sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
86
+ sparsity_block_size))
87
+
88
+ return output
88
89
 
89
90
 
90
91
  @triton.autotune(