blksprs 1.11__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/flow.py CHANGED
@@ -1,20 +1,65 @@
1
1
  import torch
2
2
  import triton
3
3
  from torch import Tensor
4
+ from torch._library import triton_op
5
+ from torch._library.triton import wrap_triton
4
6
  from triton import language as tl
5
7
 
6
- from blksprs.utils.tools import stride, get_triton_block_size
7
-
8
-
8
+ from blksprs.utils.tools import stride
9
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
10
+
11
+
12
+ @triton_op("blksprs::flow_pull_forward", mutates_args={})
13
+ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
14
+ sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
15
+ sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
16
+ with torch.no_grad():
17
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
18
+ dtype=x.dtype, device=x.device)
19
+
20
+ x_b, x_r, x_c = x.size()
21
+ x_b_s, x_r_s, x_c_s = stride(x)
22
+ o_b, o_r, o_c = output.size()
23
+ o_b_s, o_r_s, o_c_s = stride(output)
24
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
25
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
26
+ s_lut_r, s_lut_c = sparsity_lut.size()
27
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
28
+
29
+ triton_grid = lambda meta: [o_b,
30
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
31
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
32
+
33
+ (wrap_triton(flow_pull_kernel)[triton_grid]
34
+ (x,
35
+ x_b, x_b_s, x_r_s, x_c_s,
36
+ output,
37
+ o_b, o_b_s, o_r_s, o_c_s,
38
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
39
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
40
+ sparsity_reverse_lut,
41
+ sparsity_block_size))
42
+
43
+ return output
44
+
45
+
46
+ # noinspection PyUnusedLocal
47
+ @triton.autotune(
48
+ configs=get_autotune_configs(),
49
+ key=["sparsity_block_size"],
50
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
51
+ reset_to_zero=["o"]
52
+ )
9
53
  @triton.jit
10
- def kernel_blocksparse_flow_pull(x,
11
- x_b, x_b_s, x_r_s, x_c_s,
12
- o,
13
- o_b, o_b_s, o_r_s, o_c_s,
14
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
15
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
16
- r_lut,
17
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
54
+ def flow_pull_kernel(x,
55
+ x_b, x_b_s, x_r_s, x_c_s,
56
+ o,
57
+ o_b, o_b_s, o_r_s, o_c_s,
58
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
59
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
60
+ r_lut,
61
+ sparsity_block_size,
62
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
18
63
  # Get triton block indices
19
64
  pid_blk = tl.program_id(axis=0)
20
65
  pid_row = tl.program_id(axis=1)
@@ -44,25 +89,68 @@ def kernel_blocksparse_flow_pull(x,
44
89
  blk_x_idx = (rev_idx_spa * x_b_s +
45
90
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
46
91
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
47
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
92
+ blk_x_msk = (blk_x_idx >= 0 and
93
+ blk_x_idx < x_b * x_b_s)
48
94
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
49
95
 
50
96
  blk_o_idx = (pid_blk * o_b_s +
51
97
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
52
98
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
53
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
99
+ blk_o_msk = (blk_o_idx >= 0 and
100
+ blk_o_idx < o_b * o_b_s)
54
101
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
55
102
 
56
103
 
104
+ @triton_op("blksprs::flow_push_forward", mutates_args={})
105
+ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
106
+ sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
107
+ with torch.no_grad():
108
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
109
+ dtype=x.dtype, device=x.device)
110
+
111
+ x_b, x_r, x_c = x.size()
112
+ x_b_s, x_r_s, x_c_s = stride(x)
113
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
114
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
115
+ s_lut_r, s_lut_c = sparsity_lut.size()
116
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
117
+ o_b, o_r, o_c = output.size()
118
+ o_b_s, o_r_s, o_c_s = stride(output)
119
+
120
+ triton_grid = lambda meta: [x_b,
121
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
122
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
123
+
124
+ (wrap_triton(flow_push_kernel)[triton_grid]
125
+ (x,
126
+ x_b, x_b_s, x_r_s, x_c_s,
127
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
128
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
129
+ sparsity_reverse_lut,
130
+ output,
131
+ o_b, o_b_s, o_r_s, o_c_s,
132
+ sparsity_block_size))
133
+
134
+ return output
135
+
136
+
137
+ # noinspection PyUnusedLocal
138
+ @triton.autotune(
139
+ configs=get_autotune_configs(),
140
+ key=["sparsity_block_size"],
141
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
142
+ reset_to_zero=["o"]
143
+ )
57
144
  @triton.jit
58
- def kernel_blocksparse_flow_push(x,
59
- x_b, x_b_s, x_r_s, x_c_s,
60
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
61
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
62
- r_lut,
63
- o,
64
- o_b, o_b_s, o_r_s, o_c_s,
65
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
145
+ def flow_push_kernel(x,
146
+ x_b, x_b_s, x_r_s, x_c_s,
147
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
148
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
149
+ r_lut,
150
+ o,
151
+ o_b, o_b_s, o_r_s, o_c_s,
152
+ sparsity_block_size,
153
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
66
154
  # Get triton block indices
67
155
  pid_blk = tl.program_id(axis=0)
68
156
  pid_row = tl.program_id(axis=1)
@@ -92,88 +180,13 @@ def kernel_blocksparse_flow_push(x,
92
180
  blk_x_idx = (pid_blk * x_b_s +
93
181
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
94
182
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
95
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
183
+ blk_x_msk = (blk_x_idx >= 0 and
184
+ blk_x_idx < x_b * x_b_s)
96
185
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
97
186
 
98
187
  blk_o_idx = (rev_idx_spa * o_b_s +
99
188
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
100
189
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
101
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
190
+ blk_o_msk = (blk_o_idx >= 0 and
191
+ blk_o_idx < o_b * o_b_s)
102
192
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
103
-
104
-
105
- def flow_forward_pull(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
106
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
107
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
108
- dtype=x.dtype, device=x.device)
109
-
110
- x_b, x_r, x_c = x.size()
111
- x_b_s, x_r_s, x_c_s = stride(x)
112
- o_b, o_r, o_c = output.size()
113
- o_b_s, o_r_s, o_c_s = stride(output)
114
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
115
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
116
- s_lut_r, s_lut_c = sparsity_lut.size()
117
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
118
-
119
- if triton_block_size is None:
120
- triton_block_size = get_triton_block_size(sparsity_block_size)
121
-
122
- triton_grid = lambda meta: [o_b,
123
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
124
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
125
-
126
- (kernel_blocksparse_flow_pull[triton_grid]
127
- (x,
128
- x_b, x_b_s, x_r_s, x_c_s,
129
- output,
130
- o_b, o_b_s, o_r_s, o_c_s,
131
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
132
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
133
- sparsity_reverse_lut,
134
- triton_block_size))
135
-
136
- # Save for backward pass
137
- ctx.sparsity_block_size = sparsity_block_size
138
- ctx.triton_block_size = triton_block_size
139
-
140
- return output
141
-
142
-
143
- def flow_forward_push(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
144
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
145
- output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
146
- dtype=x.dtype, device=x.device)
147
-
148
- x_b, x_r, x_c = x.size()
149
- x_b_s, x_r_s, x_c_s = stride(x)
150
- s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
151
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
152
- s_lut_r, s_lut_c = sparsity_lut.size()
153
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
154
- o_b, o_r, o_c = output.size()
155
- o_b_s, o_r_s, o_c_s = stride(output)
156
-
157
- if triton_block_size is None:
158
- triton_block_size = get_triton_block_size(sparsity_block_size)
159
-
160
- triton_grid = lambda meta: [x_b,
161
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
162
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
163
-
164
- (kernel_blocksparse_flow_push[triton_grid]
165
- (x,
166
- x_b, x_b_s, x_r_s, x_c_s,
167
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
168
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
169
- sparsity_reverse_lut,
170
- output,
171
- o_b, o_b_s, o_r_s, o_c_s,
172
- triton_block_size))
173
-
174
- # Save for backward pass
175
- if ctx is not None:
176
- ctx.sparsity_block_size = sparsity_block_size
177
- ctx.triton_block_size = triton_block_size
178
-
179
- return output