blksprs 2.0rc7__py3-none-any.whl → 2.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/flow.py CHANGED
@@ -9,39 +9,41 @@ from blksprs.utils.tools import stride
  from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs


- @triton_op("blksprs::flow_pull", mutates_args={})
+ @triton_op("blksprs::flow_pull_forward", mutates_args={})
  def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
                        sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                        sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                          dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(flow_pull_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       sparsity_reverse_lut,
-       sparsity_block_size))
-
-     return output
-
-
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(flow_pull_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           sparsity_reverse_lut,
+           sparsity_block_size))
+
+         return output
+
+
+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
@@ -99,38 +101,40 @@ def flow_pull_kernel(x,
      tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


- @triton_op("blksprs::flow_push", mutates_args={})
+ @triton_op("blksprs::flow_push_forward", mutates_args={})
  def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                        sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                          dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(flow_push_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       sparsity_reverse_lut,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size))
-
-     return output
-
-
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(flow_push_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           sparsity_reverse_lut,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output
+
+
+ # noinspection PyUnusedLocal
  @triton.autotune(
      configs=get_autotune_configs(),
      key=["sparsity_block_size"],
blksprs/ops/matmul.py CHANGED
@@ -55,53 +55,54 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
          sparsity_block_size, lut["n_sparse_blocks"]))


- @triton_op("blksprs::matmul", mutates_args={})
+ @triton_op("blksprs::matmul_forward", mutates_args={})
  def matmul_forward(x: Tensor, y: Tensor,
                     sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                     sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                     _: Tensor, sparsity_lut_o: Tensor,
                     sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                          dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
-     y_b, y_r, y_c = y.size()
-     y_b_s, y_r_s, y_c_s = stride(y)
-     s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
-     s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-     s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(matmul_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       s_l_x_b, s_l_x_b_s, s_l_x_r_s,
-       s_l_x_c, s_l_x_c_s,
-       sparsity_reverse_lut_x,
-       y,
-       y_b, y_b_s, y_r_s, y_c_s,
-       s_l_y_b, s_l_y_b_s, s_l_y_r_s,
-       s_l_y_c_s,
-       sparsity_reverse_lut_y,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_lut_o,
-       s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-       sparsity_block_size))
-
-     return output
-
-
- def matmul_backward(ctx, grad_output):
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+         y_b, y_r, y_c = y.size()
+         y_b_s, y_r_s, y_c_s = stride(y)
+         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
+         s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(matmul_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s,
+           s_l_x_c, s_l_x_c_s,
+           sparsity_reverse_lut_x,
+           y,
+           y_b, y_b_s, y_r_s, y_c_s,
+           s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+           s_l_y_c_s,
+           sparsity_reverse_lut_y,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_o,
+           s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+           sparsity_block_size))
+
+         return output
+
+
+ def matmul_wrapper_backward(ctx, grad_output):
      x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
      sparsity_block_size = ctx.sparsity_block_size

@@ -187,20 +188,16 @@ def matmul_kernel(x,
                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                       ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = ((blk_x_idx >= 0 and
-                       blk_x_idx < x_b * x_b_s) and
-                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
-                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
+         blk_x_msk = (blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s)
          blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

          blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                       ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                       ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-         blk_y_msk = ((blk_y_idx >= 0 and
-                       blk_y_idx < y_b * y_b_s) and
-                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
-                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
+         blk_y_msk = (blk_y_idx >= 0 and
+                      blk_y_idx < y_b * y_b_s)
          blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

          # Perform matrix multiplication
@@ -213,10 +210,8 @@ def matmul_kernel(x,
      blk_o_idx = ((pid_blk * o_b_s) +
                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-     blk_o_msk = ((blk_o_idx >= 0 and
-                   blk_o_idx < o_b * o_b_s) and
-                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
-                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
+     blk_o_msk = (blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s)
      tl.store(o + blk_o_idx, buf, mask=blk_o_msk)


@@ -262,4 +257,4 @@ def matmul_setup_context(ctx, inputs, output):
      ctx.sparsity_block_size = sparsity_block_size


- matmul_forward.register_autograd(matmul_backward, setup_context=matmul_setup_context)
+ matmul_forward.register_autograd(matmul_wrapper_backward, setup_context=matmul_setup_context)
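
The register_autograd call above ties matmul_wrapper_backward and matmul_setup_context to the custom op. For reference, this is the standard torch.library wiring; a minimal sketch with a toy op (example::square is hypothetical, not part of blksprs):

import torch
from torch.library import custom_op


@custom_op("example::square", mutates_args=())
def square(x: torch.Tensor) -> torch.Tensor:
    return x * x


def square_setup_context(ctx, inputs, output):
    # Analogous to matmul_setup_context: stash whatever backward will need.
    x, = inputs
    ctx.save_for_backward(x)


def square_backward(ctx, grad_output):
    # Analogous to matmul_wrapper_backward: compute input gradients from saved tensors.
    x, = ctx.saved_tensors
    return 2 * x * grad_output


square.register_autograd(square_backward, setup_context=square_setup_context)

With this registration in place, square(torch.randn(3, requires_grad=True)).sum().backward() routes through square_backward.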
@@ -55,36 +55,37 @@ def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
      return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size)


- @triton_op("blksprs::broadcast_add", mutates_args={})
+ @triton_op("blksprs::broadcast_add_forward", mutates_args={})
  def broadcast_add_forward(x: Tensor, y: Tensor,
                            sparsity_lut_o: Tensor,
                            sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-     output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
-
-     x_b, x_c = x.size()
-     x_b_s, x_c_s = stride(x)
-     y_b, y_c = y.size()
-     y_b_s, y_c_s = stride(y)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-     s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (wrap_triton(broadcast_add_kernel)[triton_grid]
-      (x,
-       x_b, x_b_s, x_c_s,
-       y,
-       y_b, y_b_s, y_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-       sparsity_block_size))
-
-     return BlksprsTensor(output)
+     with torch.no_grad():
+         output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
+
+         x_b, x_c = x.size()
+         x_b_s, x_c_s = stride(x)
+         y_b, y_c = y.size()
+         y_b_s, y_c_s = stride(y)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(broadcast_add_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_c_s,
+           y,
+           y_b, y_b_s, y_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+           sparsity_block_size))
+
+         return output


  @triton.autotune(
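
A side note on the renamed qualified names ("blksprs::matmul" to "blksprs::matmul_forward", and so on): an op registered under ns::name is also reachable through the dispatcher as torch.ops.ns.name, so renaming the op moves that handle as well. A short sketch with a hypothetical op (not blksprs):

import torch
from torch.library import custom_op


@custom_op("example::add_one_forward", mutates_args=())
def add_one_forward(x: torch.Tensor) -> torch.Tensor:
    return x + 1


x = torch.randn(4)
# The Python wrapper and the dispatcher entry refer to the same registered op.
assert torch.equal(add_one_forward(x), torch.ops.example.add_one_forward(x))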