blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/flow.py CHANGED
@@ -1,25 +1,68 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl

- from blksprs.utils.tools import stride, get_triton_block_size
+ from blksprs.utils.tools import stride, get_autotune_configs


+ @triton_op("blksprs::flow_pull", mutates_args={})
+ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
+                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                          dtype=x.dtype, device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(flow_pull_kernel)[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       sparsity_reverse_lut,
+       sparsity_block_size))
+
+     return output
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+ )
  @triton.jit
- def kernel_blocksparse_flow_pull(x,
-                                  x_b, x_b_s, x_r_s, x_c_s,
-                                  o,
-                                  o_b, o_b_s, o_r_s, o_c_s,
-                                  s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                  s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                  r_lut,
-                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def flow_pull_kernel(x,
+                      x_b, x_b_s, x_r_s, x_c_s,
+                      o,
+                      o_b, o_b_s, o_r_s, o_c_s,
+                      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                      r_lut,
+                      sparsity_block_size,
+                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
      pid_blk = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
      pid_col = tl.program_id(axis=2)

+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
      # Get sparsity index of current output block consisting of its batch, row, and column index
      spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
      spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
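
Aside: the core of this release's change is visible in the hunk above. The kernels are no longer launched directly from a `torch.autograd.Function`; instead each wrapper is registered as a PyTorch custom operator via `triton_op`, and the kernel launch goes through `wrap_triton` so that `torch.compile` can trace into it rather than graph-breaking. Below is a minimal sketch of that registration pattern, independent of blksprs; the op name `mylib::add_one` and both function names are hypothetical, while the imports mirror the ones the diff adds (newer PyTorch releases also expose these publicly as `torch.library.triton_op` / `torch.library.wrap_triton`).

import torch
import triton
from torch._library import triton_op
from torch._library.triton import wrap_triton
from triton import language as tl


@triton.jit
def add_one_kernel(x_ptr, o_ptr, n, BLOCK_SIZE: tl.constexpr):
    # One program handles one contiguous chunk of BLOCK_SIZE elements
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    msk = offs < n
    tl.store(o_ptr + offs, tl.load(x_ptr + offs, mask=msk) + 1, mask=msk)


@triton_op("mylib::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: [triton.cdiv(n, meta["BLOCK_SIZE"])]
    # Launching through wrap_triton (instead of add_one_kernel[grid](...))
    # keeps the launch traceable by torch.compile
    wrap_triton(add_one_kernel)[grid](x, output, n, BLOCK_SIZE=1024)
    return output

Once registered, such an op is also reachable as torch.ops.mylib.add_one, which is the same mechanism that makes blksprs::flow_pull and blksprs::flow_push first-class operators here.
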
@@ -40,37 +83,81 @@ def kernel_blocksparse_flow_pull(x,
      rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
      rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-     if rev_idx_spa == -1:
-         tl.device_assert(False)
-         return
+     if rev_idx_spa >= 0:
+         blk_x_idx = (rev_idx_spa * x_b_s +
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = ((blk_x_idx >= 0 and
+                       blk_x_idx < x_b * x_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         blk_o_idx = (pid_blk * o_b_s +
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = ((blk_o_idx >= 0 and
+                       blk_o_idx < o_b * o_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ @triton_op("blksprs::flow_push", mutates_args={})
+ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                          dtype=x.dtype, device=x.device)

-     blk_x_idx = (rev_idx_spa * x_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+
+     triton_grid = lambda meta: [x_b,
+                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(flow_push_kernel)[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       sparsity_reverse_lut,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_block_size))

-     blk_o_idx = (pid_blk * o_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-     blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+     return output


+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+     reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_blocksparse_flow_push(x,
-                                  x_b, x_b_s, x_r_s, x_c_s,
-                                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                                  s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                  r_lut,
-                                  o,
-                                  o_b, o_b_s, o_r_s, o_c_s,
-                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def flow_push_kernel(x,
+                      x_b, x_b_s, x_r_s, x_c_s,
+                      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                      r_lut,
+                      o,
+                      o_b, o_b_s, o_r_s, o_c_s,
+                      sparsity_block_size,
+                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
      pid_blk = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
      pid_col = tl.program_id(axis=2)

+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
      # Get sparsity index of current input block consisting of its batch, row, and column index
      spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
      spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
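
Two details of the new autotuning setup in the hunk above are worth noting. First, `get_autotune_configs()` replaces the old fixed `get_triton_block_size`-derived launch; its real implementation lives in `blksprs.utils.tools` and is not shown in this diff, but a typical Triton config sweep looks roughly like the following hypothetical stand-in (values are illustrative only).

import triton

# Hypothetical sketch of blksprs.utils.tools.get_autotune_configs; the actual
# helper is not part of this diff and may choose different block sizes or warps
def get_autotune_configs():
    return [
        triton.Config({"TRITON_BLOCK_SIZE": tbs}, num_warps=num_warps)
        for tbs in (16, 32, 64, 128)
        for num_warps in (2, 4)
    ]

Second, `reset_to_zero=["o"]` matters specifically for the push kernel: the autotuner benchmarks the kernel repeatedly under different configs, and because `flow_push_kernel` accumulates into `o` with `tl.atomic_add`, each trial would otherwise compound the previous trial's results. The pull kernel only overwrites its output with `tl.store`, so its `@triton.autotune` decorator needs no reset.
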
@@ -91,56 +178,21 @@ def kernel_blocksparse_flow_push(x,
      rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
      rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-     if rev_idx_spa == -1:
-         tl.device_assert(False)
-         return
-
-     blk_x_idx = (pid_blk * x_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-     blk_o_idx = (rev_idx_spa * o_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-     blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-     tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
- def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                  sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                          dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = stride(x)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (kernel_blocksparse_flow_pull[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       sparsity_reverse_lut,
-       triton_block_size))
-
-     # Save for backward pass
-     ctx.sparsity_block_size = sparsity_block_size
-     ctx.triton_block_size = triton_block_size
-
-     return output
+     if rev_idx_spa >= 0:
+         blk_x_idx = (pid_blk * x_b_s +
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = ((blk_x_idx >= 0 and
+                       blk_x_idx < x_b * x_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         blk_o_idx = (rev_idx_spa * o_b_s +
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = ((blk_o_idx >= 0 and
+                       blk_o_idx < o_b * o_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
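
Finally, note the role of `val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)` in both kernels. Once the Triton block size is chosen by the autotuner rather than derived from the sparsity block size, it can exceed `sparsity_block_size`; in that case only the leading `val_tbs` x `val_tbs` corner of each Triton tile maps to real data, and the extra `tl.arange(...) < val_tbs` terms in the masks disable the remaining lanes. The mask arithmetic can be checked in plain PyTorch (illustrative values only, not blksprs code):

import torch

# With sparsity_block_size=16 but an autotuned TRITON_BLOCK_SIZE=32, only the
# leading 16x16 sub-tile of each 32x32 Triton tile addresses valid elements
TRITON_BLOCK_SIZE, sparsity_block_size = 32, 16
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
rng = torch.arange(TRITON_BLOCK_SIZE)
tile_msk = (rng[:, None] < val_tbs) & (rng[None, :] < val_tbs)
assert tile_msk.sum().item() == val_tbs * val_tbs  # 256 of 1024 lanes active

This also explains why the old `if rev_idx_spa == -1: tl.device_assert(False)` early-exit became `if rev_idx_spa >= 0:` guarding the whole body: with autotuned grids, a program may legitimately land on a position with no corresponding sparse block, and skipping silently replaces the device-side assertion.
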