blksprs-1.11-py3-none-any.whl → blksprs-2.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/flow.py CHANGED
@@ -1,25 +1,68 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.utils.tools import stride, get_triton_block_size
+from blksprs.utils.tools import stride, get_autotune_configs
 
 
+@triton_op("blksprs::flow_pull", mutates_args={})
+def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
+                      sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (wrap_triton(flow_pull_kernel)[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      sparsity_block_size))
+
+    return output
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[],
+)
 @triton.jit
-def kernel_blocksparse_flow_pull(x,
-                                 x_b, x_b_s, x_r_s, x_c_s,
-                                 o,
-                                 o_b, o_b_s, o_r_s, o_c_s,
-                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                 r_lut,
-                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+def flow_pull_kernel(x,
+                     x_b, x_b_s, x_r_s, x_c_s,
+                     o,
+                     o_b, o_b_s, o_r_s, o_c_s,
+                     s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                     r_lut,
+                     sparsity_block_size,
+                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
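
Note: `flow_pull_forward` is now registered as a PyTorch custom op via `@triton_op("blksprs::flow_pull", ...)` and launches the kernel through `wrap_triton`, which keeps the Triton launch traceable by `torch.compile`. The explicit `triton_block_size` parameter of the 1.x API is gone; `@triton.autotune` selects `TRITON_BLOCK_SIZE` instead. The body of `get_autotune_configs` is not part of this diff, so the sketch below is only a hypothetical illustration of the kind of config sweep such a helper could return:

```python
import triton

def get_autotune_configs():
    # Hypothetical sweep (not the library's actual implementation): each
    # config proposes one candidate tile size and warp count. With key=[],
    # triton.autotune benchmarks the candidates once and caches a single
    # winner for all input shapes.
    return [
        triton.Config({"TRITON_BLOCK_SIZE": tbs}, num_warps=num_warps)
        for tbs in (16, 32, 64)
        for num_warps in (2, 4)
    ]
```
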
@@ -42,32 +85,79 @@ def kernel_blocksparse_flow_pull(x,
 
     if rev_idx_spa >= 0:
         blk_x_idx = (rev_idx_spa * x_b_s +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = ((blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (pid_blk * o_b_s +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = ((blk_o_idx >= 0 and
+                      blk_o_idx < o_b * o_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
+@triton_op("blksprs::flow_push", mutates_args={})
+def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+    s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (wrap_triton(flow_push_kernel)[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_block_size))
+
+    return output
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[],
+    reset_to_zero=["o"]
+)
 @triton.jit
-def kernel_blocksparse_flow_push(x,
-                                 x_b, x_b_s, x_r_s, x_c_s,
-                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                 r_lut,
-                                 o,
-                                 o_b, o_b_s, o_r_s, o_c_s,
-                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+def flow_push_kernel(x,
+                     x_b, x_b_s, x_r_s, x_c_s,
+                     s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                     r_lut,
+                     o,
+                     o_b, o_b_s, o_r_s, o_c_s,
+                     sparsity_block_size,
+                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
     # Get sparsity index of current input block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
  spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -90,90 +180,19 @@ def kernel_blocksparse_flow_push(x,
90
180
 
91
181
  if rev_idx_spa >= 0:
92
182
  blk_x_idx = (pid_blk * x_b_s +
93
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
94
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
95
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
183
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
184
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
185
+ blk_x_msk = ((blk_x_idx >= 0 and
186
+ blk_x_idx < x_b * x_b_s) and
187
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
188
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
96
189
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
97
190
 
98
191
  blk_o_idx = (rev_idx_spa * o_b_s +
99
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
100
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
101
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
192
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
193
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
194
+ blk_o_msk = ((blk_o_idx >= 0 and
195
+ blk_o_idx < o_b * o_b_s) and
196
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
197
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
102
198
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
103
-
104
-
105
- def flow_forward_pull(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
106
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
107
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
108
- dtype=x.dtype, device=x.device)
109
-
110
- x_b, x_r, x_c = x.size()
111
- x_b_s, x_r_s, x_c_s = stride(x)
112
- o_b, o_r, o_c = output.size()
113
- o_b_s, o_r_s, o_c_s = stride(output)
114
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
115
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
116
- s_lut_r, s_lut_c = sparsity_lut.size()
117
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
118
-
119
- if triton_block_size is None:
120
- triton_block_size = get_triton_block_size(sparsity_block_size)
121
-
122
- triton_grid = lambda meta: [o_b,
123
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
124
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
125
-
126
- (kernel_blocksparse_flow_pull[triton_grid]
127
- (x,
128
- x_b, x_b_s, x_r_s, x_c_s,
129
- output,
130
- o_b, o_b_s, o_r_s, o_c_s,
131
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
132
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
133
- sparsity_reverse_lut,
134
- triton_block_size))
135
-
136
- # Save for backward pass
137
- ctx.sparsity_block_size = sparsity_block_size
138
- ctx.triton_block_size = triton_block_size
139
-
140
- return output
141
-
142
-
143
- def flow_forward_push(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
144
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
145
- output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
146
- dtype=x.dtype, device=x.device)
147
-
148
- x_b, x_r, x_c = x.size()
149
- x_b_s, x_r_s, x_c_s = stride(x)
150
- s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
151
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
152
- s_lut_r, s_lut_c = sparsity_lut.size()
153
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
154
- o_b, o_r, o_c = output.size()
155
- o_b_s, o_r_s, o_c_s = stride(output)
156
-
157
- if triton_block_size is None:
158
- triton_block_size = get_triton_block_size(sparsity_block_size)
159
-
160
- triton_grid = lambda meta: [x_b,
161
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
162
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
163
-
164
- (kernel_blocksparse_flow_push[triton_grid]
165
- (x,
166
- x_b, x_b_s, x_r_s, x_c_s,
167
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
168
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
169
- sparsity_reverse_lut,
170
- output,
171
- o_b, o_b_s, o_r_s, o_c_s,
172
- triton_block_size))
173
-
174
- # Save for backward pass
175
- if ctx is not None:
176
- ctx.sparsity_block_size = sparsity_block_size
177
- ctx.triton_block_size = triton_block_size
178
-
179
- return output
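
Note: the old `ctx`-based wrappers `flow_forward_pull` and `flow_forward_push` are removed; callers use the registered ops, which no longer take a `triton_block_size` argument. A hedged usage sketch follows; the placeholder inputs and the assumed LUT column layout `(batch, row, col)` are illustrative only, since LUT construction happens in layout helpers this diff does not show:

```python
import torch

from blksprs.ops.flow import flow_pull_forward

# Illustrative single dense 32x32 block (identity layout); real LUTs come
# from the library's layout helpers.
sparsity_block_size, n_sparse_blocks = 32, 1
x = torch.randn(n_sparse_blocks, sparsity_block_size, sparsity_block_size,
                device="cuda")
sparsity_layout_o = torch.ones(1, 1, 1, dtype=torch.int32, device="cuda")
sparsity_lut = torch.tensor([[0, 0, 0]], device="cuda")   # assumed (batch, row, col)
sparsity_reverse_lut = torch.tensor([0], device="cuda")   # layout idx -> block idx

# Direct call; @triton_op also exposes the op as torch.ops.blksprs.flow_pull.
out = flow_pull_forward(x, sparsity_layout_o, sparsity_lut,
                        sparsity_reverse_lut, sparsity_block_size,
                        n_sparse_blocks)
```
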