blksprs-1.3-py3-none-any.whl → blksprs-1.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py ADDED
@@ -0,0 +1,18 @@
+ from blksprs.ops.conversion import to_dense, to_sparse
+ from blksprs.ops.distribution import gather, scatter, scatter_reduce
+ from blksprs.ops.exp import exp
+ from blksprs.ops.matmul import matmul
+ from blksprs.ops.softmax import softmax
+ from blksprs.ops.transpose import transpose
+
+ class layout:
+     from blksprs.layouting.distribution_layout import build_distribution_layout
+     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
+
+ class misc:
+     from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
+     from blksprs.misc.repeat_interleave import repeat_interleave
+     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
+
+ class util:
+     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
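The new top-level `__init__.py` flattens the public API: core ops are importable directly from `blksprs`, while layouting, miscellaneous, and utility helpers are grouped under the `layout`, `misc`, and `util` namespace classes. A minimal usage sketch of the resulting surface (shapes and device are illustrative assumptions, not part of the package):

```python
import torch
import blksprs as bs

# Illustrative setup: any shape whose rows/columns are multiples of the
# sparsity block size works; a CUDA device is required by the kernels.
sparsity_block_size = 32
x = torch.randn(2, 64, 64, device="cuda")

# Namespaced helpers re-exported by blksprs/__init__.py
layout = bs.layout.build_sparsity_layout(x, sparsity_block_size)
x_sparse = bs.to_sparse(x, layout, sparsity_block_size)
probs = bs.softmax(x_sparse, layout, sparsity_block_size)
```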
blksprs/layouting/distribution_layout.py CHANGED
@@ -31,7 +31,7 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
      sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()

      output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                          device=indices.device, dtype=torch.int32)
+                          dtype=torch.bool, device=indices.device)

      i_b, i_r, i_c = indices.size()
      i_b_s, i_r_s, i_c_s = indices.stride()
blksprs/layouting/sparsity_layout.py CHANGED
@@ -27,7 +27,7 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
      validate_device(x)

      output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
-                          device=x.device, dtype=torch.int32)
+                          dtype=torch.bool, device=x.device)

      x_b, x_r, x_c = x.size()
      x_b_s, x_r_s, x_c_s = x.stride()
@@ -117,7 +117,7 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
      o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
      o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)

-     output = torch.zeros(o_b, o_r, o_c, device=x.device, dtype=torch.int32)
+     output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)

      x_b, x_r, x_c = x.size()
      x_b_s, x_r_s, x_c_s = x.stride()
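Both builders (and `build_distribution_layout` above) now allocate layouts as `torch.bool` rather than `torch.int32`. A hedged sketch of the one visible consequence: element-wise comparisons against 0/1 integer layouts need an explicit cast, which is exactly what the updated README test further below does with `.to(torch.int)`:

```python
import torch

# Illustration only: a layout as the 1.4.1 builders return it (bool)
# next to the 0/1 integer form used when layouts are built by hand.
layout_bool = torch.tensor([[[True, False], [False, True]]])
layout_int = torch.tensor([[[1, 0], [0, 1]]])

# Cast before comparing, otherwise the dtypes disagree.
assert torch.equal(layout_bool.long(), layout_int)
```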
blksprs/misc/broadcast_addition.py → blksprs/misc/broadcast_ops.py RENAMED
@@ -8,8 +8,8 @@ from blksprs.utils.validation import validate_contiguous, validate_device, \
      validate_sparsity_block_size, validate_triton_block_size


- def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                        sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                   sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
      """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
      compressed form.

@@ -25,6 +25,9 @@ def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
          output tensor corresponds to x(i) + y(j).

      """
+     x = x.contiguous()
+     y = y.contiguous()
+
      validate_device(x, y)
      validate_contiguous(x, y)
      if x.size(-1) != y.size(-1):
@@ -70,12 +73,12 @@ def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
      return output


- def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                           sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-     """Wrapper for ``broadcast_addition`` with negated y.
+ def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                   sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Wrapper for ``broadcast_add`` with negated y.

      """
-     return broadcast_addition(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
+     return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)


  @triton.jit
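For reference, the documented semantics of `broadcast_add` (each output entry equals x(i) + y(j)) are those of an outer sum; a dense-only sketch of the same identity, independent of the block-sparse machinery:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([10.0, 20.0, 30.0])

# out[i, j] = x[i] + y[j], the dense analogue of broadcast_add;
# broadcast_sub is the same with y negated.
out = x.unsqueeze(-1) + y.unsqueeze(0)
assert out[0, 1].item() == 21.0
```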
blksprs/misc/repeat_interleave.py CHANGED
@@ -27,6 +27,8 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
          Tensor: The sparsity layout of the resulting output tensor.

      """
+     x = x.contiguous()
+
      validate_dimensions(x)
      validate_contiguous(x)
      validate_device(x)
blksprs/misc/row_wise.py ADDED
@@ -0,0 +1,390 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
+     validate_sparsity_block_size, validate_triton_block_size
+
+
+ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                  flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+     """Computes the row-wise sum of a block-sparse tensor.
+
+     Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
+     of the corresponding row.
+
+     Note:
+         If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
+             (default ``False``).
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
+             of the input and the sparsity layout of the output tensor.
+
+     """
+     x = x.contiguous()
+
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+     sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+     sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+     sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                    (sparsity_layout_output_flat == 1) -
+                                    (1 * (sparsity_layout_output_flat == 0)))
+
+     n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout, sparsity_lut,
+                         sparsity_layout_output, sparsity_reverse_lut_output)
+
+     output = torch.zeros(size=(n_sparse_blocks_output,
+                                sparsity_block_size,
+                                1 if flag_slice_only else sparsity_block_size),
+                          device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = x.stride()
+     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+     s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     triton_grid = lambda meta: [x_b,
+                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_blocksparse_row_wise_sum[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+       output,
+       o_b, o_b_s, o_r_s,
+       s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+       sparsity_reverse_lut_output,
+       triton_block_size))
+
+     return (output, sparsity_layout_output)
+
+
+ @triton.jit
+ def kernel_blocksparse_row_wise_sum(x,
+                                     x_b, x_b_s, x_r_s, x_c_s,
+                                     s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                     o,
+                                     o_b, o_b_s, o_r_s,
+                                     s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                                     r_lut_o,
+                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+     spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+     spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+
+     # Load reverse sparsity index for current block
+     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                        spa_row * s_l_o_r_s)
+     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+     blk_idx = ((pid_blk * x_b_s) +
+                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_msk = (blk_idx < x_b * x_b_s)
+     blk = tl.load(x + blk_idx, mask=blk_msk)
+
+     buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+
+     o_idx = (rev_idx_spa * o_b_s +
+              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+              (tl.arange(0, 1))[None, :])
+     o_msk = (o_idx < o_b * o_b_s)
+     tl.atomic_add(o + o_idx, buf, o_msk)
+
+
+ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                  flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+     """Computes the row-wise max of a block-sparse tensor.
+
+     Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
+     maximum of the corresponding row.
+
+     Note:
+         If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
+             (default ``False``).
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
+             of the input and the sparsity layout of the output tensor.
+
+     """
+     x = x.contiguous()
+
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+     sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+     sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+     sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                    (sparsity_layout_output_flat == 1) -
+                                    (1 * (sparsity_layout_output_flat == 0)))
+
+     n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout, sparsity_lut,
+                         sparsity_layout_output, sparsity_reverse_lut_output)
+
+     output = torch.full(size=(n_sparse_blocks_output,
+                               sparsity_block_size,
+                               1 if flag_slice_only else sparsity_block_size),
+                         fill_value=float("-inf"),
+                         device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = x.stride()
+     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+     s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     triton_grid = lambda meta: [x_b,
+                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_blocksparse_row_wise_max[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+       output,
+       o_b, o_b_s, o_r_s,
+       s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+       sparsity_reverse_lut_output,
+       triton_block_size))
+
+     return output, sparsity_layout_output
+
+
+ @triton.jit
+ def kernel_blocksparse_row_wise_max(x,
+                                     x_b, x_b_s, x_r_s, x_c_s,
+                                     s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                     o,
+                                     o_b, o_b_s, o_r_s,
+                                     s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                                     r_lut_o,
+                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+     spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+     spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+     spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+
+     # Load reverse sparsity index for current block
+     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                        spa_row * s_l_o_r_s)
+     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+     blk_idx = ((pid_blk * x_b_s) +
+                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_msk = (blk_idx < x_b * x_b_s)
+     blk = tl.load(x + blk_idx, mask=blk_msk)
+
+     buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+
+     o_idx = (rev_idx_spa * o_b_s +
+              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+              (tl.arange(0, 1))[None, :])
+     o_msk = (o_idx < o_b * o_b_s)
+     tl.atomic_max(o + o_idx, buf, o_msk)
+
+
+ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
+                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
+         y (Tensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
+             compressed form.
+
+     """
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_lut = torch.nonzero(sparsity_layout_x).contiguous()
+
+     sparsity_layout_rwm, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)
+     sparsity_layout_rwm_flat = sparsity_layout_rwm.reshape(-1)
+     sparsity_reverse_lut_rwm = ((torch.cumsum(sparsity_layout_rwm_flat, dim=-1) - 1) *
+                                 (sparsity_layout_rwm_flat == 1) -
+                                 (1 * (sparsity_layout_rwm_flat == 0)))
+
+     validate_contiguous(sparsity_layout_x, sparsity_lut, sparsity_reverse_lut_rwm)
+
+     output = torch.empty_like(x)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = x.stride()
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+     y_b, y_r, y_c = y.size()
+     y_b_s, y_r_s, y_c_s = y.stride()
+     s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
+     s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_rwm.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_blocksparse_row_wise_add[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       y, y_b, y_b_s, y_r_s, y_c_s,
+       s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+       sparsity_reverse_lut_rwm,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       triton_block_size
+       ))
+
+     return output
+
+
+ def row_wise_sub(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
+                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Wrapper for ``row_wise_add`` with negated y.
+
+     """
+     return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size, triton_block_size)
+
+
+ @triton.jit
+ def kernel_blocksparse_row_wise_add(x,
+                                     x_b, x_b_s, x_r_s, x_c_s,
+                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                     y, y_b, y_b_s, y_r_s, y_c_s,
+                                     s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+                                     r_lut_y,
+                                     o,
+                                     o_b, o_b_s, o_r_s, o_c_s,
+                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch and row index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     # Get reverse sparsity indices for s
+     rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
+                          spa_row * s_l_y_r_s)
+     rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
+     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+     if rev_idx_spa_s == -1:
+         assert False, "Invalid sparsity block"
+
+     # Load x block
+     blk_x_idx = ((pid_blk * x_b_s) +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_x_msk = (blk_x_idx < x_b * x_b_s)
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     # Load sum block
+     blk_s_idx = (rev_idx_spa_s * y_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+                  (tl.arange(0, 1) * y_c_s)[None, :])
+     blk_s_msk = (blk_s_idx < y_b * y_b_s)
+     blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
+
+     # Compute exp
+     buf = blk_x + tl.broadcast_to(blk_s, (TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE))
+
+     # debug
+     asdf = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1.0, dtype=tl.float32)
+
+     # Store block
+     blk_o_idx = ((pid_blk * o_b_s) +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
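This new module replaces the former `blksprs/ops/row_wise_sum.py` (deleted at the end of this diff) and bundles sum, max, add, and sub. A hedged usage sketch assembled from the docstrings above (CUDA device assumed):

```python
import torch
import blksprs as bs

sparsity_block_size = 32
x = torch.randn(2, 64, 64, device="cuda")
layout = bs.layout.build_sparsity_layout(x, sparsity_block_size)
x_sparse = bs.to_sparse(x, layout, sparsity_block_size)

# Reductions return (values, output_layout); with flag_slice_only the
# values keep only one column per row.
sums, layout_sums = bs.misc.row_wise_sum(x_sparse, layout, sparsity_block_size,
                                         flag_slice_only=True)
maxima, layout_max = bs.misc.row_wise_max(x_sparse, layout, sparsity_block_size,
                                          flag_slice_only=True)

# row_wise_sub subtracts the single per-row value of y from every entry
# in the corresponding row of x (row_wise_add is the positive variant).
x_shifted = bs.misc.row_wise_sub(x_sparse, layout, maxima, sparsity_block_size)
```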
blksprs/ops/conversion.py CHANGED
@@ -28,6 +28,8 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
          Tensor: The block-sparse tensor converted to regular form.

      """
+     x = x.contiguous()
+
      validate_dimensions(x)
      validate_contiguous(x, sparsity_layout)
      validate_device(x)
@@ -156,6 +158,8 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
          Tensor: The block-sparse tensor converted to compressed form.

      """
+     x = x.contiguous()
+
      validate_dimensions(x)
      validate_contiguous(x)
      validate_device(x)
@@ -282,6 +286,8 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
          Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.

      """
+     x = x.contiguous()
+
      validate_dimensions(x)
      validate_contiguous(x, sparsity_layout_from)
      validate_device(x)
blksprs/ops/distribution.py CHANGED
@@ -24,6 +24,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
          Tensor: The result of the gather operation as a block-sparse tensor in compressed form.

      """
+     src = src.contiguous()
+     idx = idx.contiguous()
+
      validate_dimensions(src, idx)
      validate_contiguous(src, idx)
      validate_dtype_int(idx)
@@ -200,6 +203,9 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
          Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.

      """
+     src = src.contiguous()
+     idx = idx.contiguous()
+
      validate_dimensions(src, idx)
      validate_contiguous(src, idx)
      validate_dtype_int(idx)
blksprs/ops/exp.py CHANGED
@@ -25,6 +25,8 @@ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> T
      compressed form.

      """
+     x = x.contiguous()
+
      validate_dimensions(x)
      validate_contiguous(x)
      validate_device(x)
blksprs/ops/matmul.py CHANGED
@@ -6,7 +6,7 @@ from triton import language as tl
  from blksprs.ops.transpose import transpose
  from blksprs.utils.tools import get_triton_block_size
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float


  def matmul(x: Tensor, sparsity_layout_x: Tensor,
@@ -30,8 +30,12 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
          Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.

      """
+     x = x.contiguous()
+     y = y.contiguous()
+
      validate_dimensions(x, y)
      validate_contiguous(x, y)
+     validate_dtype_float(x, y)
      validate_device(x, y)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout_x), (y, sparsity_layout_y))
      if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
@@ -211,7 +215,7 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
          blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

          # Perform matrix multiplication
-         buf += tl.dot(blk_x, blk_y)
+         buf += tl.dot(blk_x, blk_y, input_precision="tf32")

          # Store output
          blk_o_idx = ((pid_blk * o_b_s) +
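`tl.dot` is now pinned to TF32 inputs. For orientation only (this is plain PyTorch, not the package's code): TF32 keeps fp32 range but rounds the mantissa, so results drift from strict fp32 in the low decimals, which is why the README test further below compares with `atol=2e-2`:

```python
import torch

a = torch.randn(256, 256, device="cuda")
b = torch.randn(256, 256, device="cuda")

# Dense analogue of the kernel change: allow TF32 tensor cores, then
# compare against a strict fp32 product with a loose tolerance.
torch.backends.cuda.matmul.allow_tf32 = True
c_tf32 = a @ b
torch.backends.cuda.matmul.allow_tf32 = False
c_fp32 = a @ b
assert torch.allclose(c_tf32, c_fp32, atol=2e-2)
```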
blksprs/ops/softmax.py CHANGED
@@ -4,7 +4,7 @@ from torch import Tensor
  from triton import language as tl

  from blksprs.ops.exp import exp
- from blksprs.ops.row_wise_sum import row_wise_sum
+ from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
  from blksprs.utils.tools import get_triton_block_size
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
      validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
@@ -26,6 +26,8 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
          Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.

      """
+     x = x.contiguous()
+
      validate_dimensions(x)
      validate_contiguous(x)
      validate_device(x)
@@ -33,12 +35,6 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
      validate_sparsity_block_size(sparsity_block_size, x)
      validate_triton_block_size(triton_block_size, sparsity_block_size)

-     if x.size(0) != 0:
-         max_val = torch.max(x).item()
-     else:
-         max_val = 0
-     x_scaled = x - max_val
-
      sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

      sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
@@ -49,7 +45,7 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton

      validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)

-     return _BlocksparseSoftmax.apply(x_scaled, sparsity_layout,
+     return _BlocksparseSoftmax.apply(x, sparsity_layout,
                                       sparsity_lut,
                                       sparsity_reverse_lut_rws,
                                       sparsity_block_size, triton_block_size)
@@ -64,13 +60,17 @@ class _BlocksparseSoftmax(torch.autograd.Function):
                  sparsity_block_size: int, triton_block_size: int) -> Tensor:
          output = torch.empty_like(x)

-         x_b, x_r, x_c = x.shape
+         x_b, x_r, x_c = x.size()
          x_b_s, x_r_s, x_c_s = x.stride()
-         s_lut_r, s_lut_c = sparsity_lut.shape
+         s_lut_r, s_lut_c = sparsity_lut.size()
          s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
-         o_b, o_r, o_c = output.shape
+         o_b, o_r, o_c = output.size()

-         x_exp = exp(x, sparsity_block_size, triton_block_size=triton_block_size)
+         x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
+                                                            flag_slice_only=True,
+                                                            triton_block_size=triton_block_size)
+         x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
+         x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)
          x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                                 flag_slice_only=True,
                                                                 triton_block_size=triton_block_size)
@@ -174,7 +174,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
          spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
          spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

-         # Get reverse sparsity indices for x
+         # Get reverse sparsity indices for s
          rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                               spa_row * s_l_s_r_s)
          rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
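The stabilization moves from subtracting a single global maximum on the host (removed above) into the per-row form: `row_wise_max` and `row_wise_sub` shift each row before `exp`. The dense identity this relies on, sketched for reference:

```python
import torch

x = torch.randn(4, 8)

# Softmax is shift-invariant per row: subtracting the row maximum leaves
# the result unchanged while keeping exp() from overflowing.
row_max = x.max(dim=-1, keepdim=True).values
exps = torch.exp(x - row_max)
stable = exps / exps.sum(dim=-1, keepdim=True)
assert torch.allclose(stable, torch.softmax(x, dim=-1))
```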
blksprs/ops/transpose.py CHANGED
@@ -26,6 +26,8 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
          Tensor: The sparsity layout of the transposed tensor.

      """
+     x = x.contiguous()
+
      validate_dimensions(x)
      validate_contiguous(x)
      validate_device(x)
blksprs/utils/tools.py CHANGED
@@ -1,10 +1,12 @@
  import torch
  from torch import Tensor, Size

+ from blksprs.utils.validation import _set_skip_validation
+

  def do_shape_blocksparse(x: Tensor):
      if x.dim() == 3:
-         return x, x.size()
+         return x.contiguous(), x.size()

      return x.reshape(-1, x.size(-2), x.size(-1)), x.size()

@@ -18,3 +20,7 @@ def undo_shape_blocksparse(x: Tensor, shape: Size):

  def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
      return min(sparsity_block_size, limit)
+
+
+ def disable_validation():
+     _set_skip_validation(True)
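`disable_validation` is the new public switch for the `validate_*` checks, re-exported through `blksprs.util`; a brief usage sketch:

```python
import blksprs as bs

# Turns off all validate_* checks globally, e.g. for benchmarked hot
# paths where inputs are already known to be well-formed.
bs.util.disable_validation()
```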
blksprs/utils/validation.py CHANGED
@@ -1,9 +1,10 @@
  import torch
  from torch import Tensor

+ VALIDATION = True

  def validate_dimensions(*tensors: Tensor) -> None:
-     if _skip_validation():
+     if _check_skip_validation():
          return

      for tensor in tensors:
@@ -12,7 +13,7 @@ def validate_dimensions(*tensors: Tensor) -> None:


  def validate_contiguous(*tensors: Tensor) -> None:
-     if _skip_validation():
+     if _check_skip_validation():
          return

      for tensor in tensors:
@@ -21,7 +22,7 @@ def validate_contiguous(*tensors: Tensor) -> None:


  def validate_dtype_float(*tensors: Tensor) -> None:
-     if _skip_validation():
+     if _check_skip_validation():
          return

      for tensor in tensors:
@@ -30,7 +31,7 @@ def validate_dtype_float(*tensors: Tensor) -> None:


  def validate_dtype_int(*tensors: Tensor) -> None:
-     if _skip_validation():
+     if _check_skip_validation():
          return

      for tensor in tensors:
@@ -39,7 +40,7 @@ def validate_dtype_int(*tensors: Tensor) -> None:


  def validate_device(*tensors: Tensor) -> None:
-     if _skip_validation():
+     if _check_skip_validation():
          return

      device = None
@@ -56,7 +57,7 @@ def validate_device(*tensors: Tensor) -> None:


  def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
-     if _skip_validation():
+     if _check_skip_validation():
          return

      for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
@@ -73,7 +74,7 @@ def _validate_sparsity_layout_values(sparsity_layout: Tensor):
          raise ValueError("Sparsity layout values must be either 0 or 1")


  def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
-     if _skip_validation():
+     if _check_skip_validation():
          return

      if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
@@ -84,7 +85,7 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
              raise ValueError("Tensor sizes must be divisible by sparsity block size")


  def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
-     if _skip_validation():
+     if _check_skip_validation():
          return

      if triton_block_size is None:
@@ -93,5 +94,9 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
      if triton_block_size > sparsity_block_size:
          raise ValueError("Triton block size cannot be larger than sparsity block size")

- def _skip_validation():
-     return False
+ def _check_skip_validation():
+     return not VALIDATION
+
+ def _set_skip_validation(skip_validation: bool):
+     global VALIDATION
+     VALIDATION = not skip_validation
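The hard-coded `_skip_validation()` stub becomes a module-level flag with a setter. The pattern reduced to its essentials (illustrative names only, not imports from blksprs):

```python
# Module-level toggle mirroring VALIDATION/_check_skip_validation above.
ENABLED = True

def _set_enabled(value: bool) -> None:
    global ENABLED
    ENABLED = value

def check_positive(x: int) -> None:
    if not ENABLED:  # early exit, as in _check_skip_validation()
        return
    if x < 0:
        raise ValueError("negative input")

_set_enabled(False)
check_positive(-1)  # no longer raises once checks are disabled
```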
blksprs-1.3.dist-info/METADATA → blksprs-1.4.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.3
+ Version: 1.4.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -8,10 +8,8 @@ Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
  Requires-Python: >=3.11
  Description-Content-Type: text/markdown
  Requires-Dist: torch
- Provides-Extra: deploy
- Requires-Dist: build; extra == "deploy"
- Requires-Dist: twine; extra == "deploy"
- Requires-Dist: pdoc3; extra == "deploy"
+ Provides-Extra: build
+ Requires-Dist: build; extra == "build"
  Provides-Extra: test
  Requires-Dist: pytest; extra == "test"
  Requires-Dist: pytest-xdist; extra == "test"
@@ -83,14 +81,7 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes

  ```python
  import torch
-
- from blksprs.layouting.sparsity_layout import build_sparsity_layout
- from blksprs.ops.conversion import to_sparse, to_dense
- from blksprs.ops.matmul import matmul
- from blksprs.ops.row_wise_sum import row_wise_sum
- from blksprs.ops.softmax import softmax
- from blksprs.ops.transpose import transpose
- from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+ import blksprs as bs


  def test_readme():
@@ -112,47 +103,57 @@ def test_readme():
      y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

      # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
-     x_dense, x_shape_original = do_shape_blocksparse(x)
-     y_dense, y_shape_original = do_shape_blocksparse(y)
+     x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
+     y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)

      # Create sparsity layouts from existing tensors
-     sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
-     sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
+     sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
+                                                         triton_block_size=triton_block_size)
+     sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
+                                                         triton_block_size=triton_block_size)

      # Create random sparsity layout for output tensor
      sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

      # Convert tensors to sparse tensors for matrix multiplication
-     x_sparse = to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
-     y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+     x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
+     y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)

      # Perform matrix multiplication
-     o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
-                       triton_block_size=triton_block_size)
-     o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+     o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
+                          sparsity_block_size,
+                          triton_block_size=triton_block_size)
+
+     # Apply element-wise operation
+     o_sparse = torch.add(o_sparse, 1)
+
+     o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)

      # Sanity check
      o_torch = torch.matmul(x_dense, y_dense)
+     o_torch = torch.add(o_torch, 1)

      # Perform round trip to set sparse blocks to 0
-     o_torch_round_trip = to_dense(
-         to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
+     o_torch_round_trip = bs.to_dense(
+         bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
          sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)

      # Assert that the output is correct
      assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

      # Assert that the output has the correct sparsity layout
-     actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
-     assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
+     actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
+                                                                triton_block_size=triton_block_size)
+     assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

      # Convert output tensor back to original shape
-     o = undo_shape_blocksparse(o_dense, x_shape_original)
+     o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

      # Other available functions
-     transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-     softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-     row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+     bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+     bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+     bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+     bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)


  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
blksprs-1.4.1.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+ blksprs/__init__.py,sha256=ORAVhGR91G1wyIOs9Wecv-xfmjju3bJ4Jynq_SGOVY4,833
+ blksprs/layouting/distribution_layout.py,sha256=Xd8KjZwI87L9EL1Bw5SGUW9YztFD5q0Ygr99sffvdak,4939
+ blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
+ blksprs/misc/broadcast_ops.py,sha256=RTcqvx6X_THRBb55jipeEe63YSLIAh27jdpuze0aSek,5308
+ blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
+ blksprs/misc/row_wise.py,sha256=KCDO5ry5TkjI88LLD_QINZwBkzfmjoQpOOvYLfpUn5I,16853
+ blksprs/ops/conversion.py,sha256=h1c5T74rQjqYgY9dwWXfPTXRpgzy0dtAhCmtUp8-6uo,21332
+ blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
+ blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
+ blksprs/ops/matmul.py,sha256=6DaYxecJgwiW8L-UISkgyNyzQ31AAkmDL-Oq1EjHt98,11210
+ blksprs/ops/softmax.py,sha256=cSTxDnNmMRlJGOlCSpdg1U5KUIFpVtHulz8fteJFeh0,11972
+ blksprs/ops/transpose.py,sha256=et8R124L29TUqihci18ms_hBoYXTtPu5LXgEA8sxk_w,6744
+ blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+ blksprs/utils/tools.py,sha256=RKGWCGd5h1qFOIoShsdJObx4-QsS0RxCyzFie0geNxo,596
+ blksprs/utils/validation.py,sha256=Gsx3aah6355bWXRPpbFuZ1p0fOrYduIqaM3ON9d5NiI,3197
+ blksprs-1.4.1.dist-info/METADATA,sha256=3xRmBFHv2U2KnrW3_QX3003SHLkQ1JCaSqh4AUBsJD4,7609
+ blksprs-1.4.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+ blksprs-1.4.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+ blksprs-1.4.1.dist-info/RECORD,,
blksprs/ops/row_wise_sum.py DELETED
@@ -1,231 +0,0 @@
- import torch
- import triton
- from torch import Tensor
- from triton import language as tl
-
- from blksprs.utils.tools import get_triton_block_size
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
-
-
- def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                  flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
-     """Computes the row-wise sum of a block-sparse tensor.
-
-     Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
-     of the corresponding row.
-
-     Note:
-         If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
-
-     Args:
-         x (Tensor): A block-sparse tensor in compressed form.
-         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
-         sparsity_block_size (int): The size of the sparsity blocks.
-         flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
-             (default ``False``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-
-     Returns:
-         tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
-             of the input and the sparsity layout of the output tensor.
-
-     """
-     validate_dimensions(x)
-     validate_contiguous(x)
-     validate_device(x)
-     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
-     validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-     sparsity_layout_flat = sparsity_layout.reshape(-1)
-     sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-
-     sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-     sparsity_lut_output = torch.nonzero(sparsity_layout_output).contiguous()
-     sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
-     sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
-                                    (sparsity_layout_output_flat == 1) -
-                                    (1 * (sparsity_layout_output_flat == 0)))
-
-     n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut,
-                         sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output)
-
-     return (_BlocksparseRowWiseSum.apply(x,
-                                          sparsity_layout, sparsity_lut, sparsity_reverse_lut,
-                                          sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output,
-                                          n_sparse_blocks_output,
-                                          flag_slice_only,
-                                          sparsity_block_size, triton_block_size),
-             sparsity_layout_output)
-
-
- class _BlocksparseRowWiseSum(torch.autograd.Function):
-     IMPLEMENTATION = "atomic_add"
-
-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_layout_output: Tensor, sparsity_lut_output: Tensor, sparsity_reverse_lut_output: Tensor,
-                 n_sparse_blocks_output: int,
-                 flag_slice_only: bool,
-                 sparsity_block_size: int, triton_block_size: int) -> Tensor:
-         output = torch.zeros(size=(n_sparse_blocks_output,
-                                    sparsity_block_size,
-                                    1 if flag_slice_only else sparsity_block_size),
-                              device=x.device)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = x.stride()
-         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout.size()
-         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout.stride()
-         s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-         s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = output.stride()
-         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
-         s_lut_o_r, s_lut_o_c = sparsity_lut_output.size()
-         s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_output.stride()
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         if _BlocksparseRowWiseSum.IMPLEMENTATION == "basic":
-             triton_grid = lambda meta: [o_b,
-                                         triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"])]
-
-             (_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum[triton_grid]
-              (x,
-               x_b, x_b_s, x_r_s, x_c_s,
-               s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-               sparsity_reverse_lut,
-               output,
-               o_b, o_b_s, o_r_s,
-               sparsity_lut_output, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-               sparsity_block_size,
-               triton_block_size))
-         elif _BlocksparseRowWiseSum.IMPLEMENTATION == "atomic_add":
-             triton_grid = lambda meta: [x_b,
-                                         triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                         triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-             (_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum_atomic_add[triton_grid]
-              (x,
-               x_b, x_b_s, x_r_s, x_c_s,
-               sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-               output,
-               o_b, o_b_s, o_r_s,
-               s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-               sparsity_reverse_lut_output,
-               triton_block_size))
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         raise NotImplementedError
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_row_wise_sum(x,
-                                         x_b, x_b_s, x_r_s, x_c_s,
-                                         s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-                                         r_lut_x,
-                                         o,
-                                         o_b, o_b_s, o_r_s,
-                                         s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                                         sparsity_block_size,
-                                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-
-         # Get position of current sparsity block consisting of its batch and row index
-         spa_bat_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-         spa_bat_msk = (spa_bat_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_bat = tl.load(s_lut_o + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-         spa_row_msk = (spa_row_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_row = tl.load(s_lut_o + spa_row_idx, mask=spa_row_msk)
-
-         buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, 1), dtype=tl.float32)
-
-         # Slide over triton block sized segments of input tensor
-         for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
-             # Convert to segment index of sparsity layout
-             i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
-             # Calculate the triton segment index within a block
-             i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
-
-             # Load reverse sparsity index for current block
-             rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
-                                spa_row * s_l_x_r_s +
-                                i_seg_spa * s_l_x_c_s)
-             rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
-             rev_idx_spa = tl.load(r_lut_x + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-             # If block is present commence operations
-             if rev_idx_spa >= 0:
-                 blk_idx = ((rev_idx_spa * x_b_s) +
-                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                            ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                              tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-                 blk_msk = (blk_idx < x_b * x_b_s)
-                 blk = tl.load(x + blk_idx, mask=blk_msk)
-
-                 buf = buf + tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
-
-         o_idx = (pid_blk * o_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                  (tl.arange(0, 1))[None, :])
-         o_msk = (o_idx < o_b * o_b_s)
-         tl.store(o + o_idx, buf, o_msk)
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_row_wise_sum_atomic_add(x,
-                                                    x_b, x_b_s, x_r_s, x_c_s,
-                                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                                    o,
-                                                    o_b, o_b_s, o_r_s,
-                                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-                                                    r_lut_o,
-                                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch and row index
-         spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-         spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
-         spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-         spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
-         spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
-
-         # Load reverse sparsity index for current block
-         rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
-                            spa_row * s_l_o_r_s)
-         rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
-         rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-         blk_idx = ((pid_blk * x_b_s) +
-                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_msk = (blk_idx < x_b * x_b_s)
-         blk = tl.load(x + blk_idx, mask=blk_msk)
-
-         buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
-
-         o_idx = (rev_idx_spa * o_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                  (tl.arange(0, 1))[None, :])
-         o_msk = (o_idx < o_b * o_b_s)
-         tl.atomic_add(o + o_idx, buf, o_msk)
blksprs-1.3.dist-info/RECORD DELETED
@@ -1,18 +0,0 @@
- blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
- blksprs/layouting/sparsity_layout.py,sha256=TtADT_WWcZpW3zyGy6KAgkAo44gDryXZqdJLZGEX2V8,7895
- blksprs/misc/broadcast_addition.py,sha256=vf1Hdqz9Uyqykto3DCjmdyepMzpMXL238SpANQqRAwI,5297
- blksprs/misc/repeat_interleave.py,sha256=WrIp7uJsnvjIhFeLYPfkL2j5vXyKmDQGrJ69b3Y0lQ8,5644
- blksprs/ops/conversion.py,sha256=-AOzj_j3WrBLGIgd2oVPvYS8XKfzlvGtSIWzW_qP1lk,21260
- blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
- blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
- blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
- blksprs/ops/row_wise_sum.py,sha256=ojuSejV37cLtRNS3lBfknA5KY3TEg8EHxOqVT6JZzoM,11387
- blksprs/ops/softmax.py,sha256=ZyeAVqmG_VzJ72FArGrpUSFfoSM4GPxyubrmNKERVIA,11654
- blksprs/ops/transpose.py,sha256=cX_E3b-QMhsUDNn9D8HVkYesc2JBc-EcVBUZfCWExM8,6720
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
- blksprs/utils/tools.py,sha256=DwophH01AeNTZAo0B1uWbKFSGBQjI5z0WmFnYKh-BBk,465
- blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
- blksprs-1.3.dist-info/METADATA,sha256=bs4_e4DjSYyAQ354tLVNIKcGLkww_-C2AfHnJIMdjA8,7515
- blksprs-1.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
- blksprs-1.3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-1.3.dist-info/RECORD,,
File without changes