blksprs 1.4.1__py3-none-any.whl → 1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -15,4 +15,7 @@ class misc:
      from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
  
  class util:
-     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+
+ class experimental:
+     from blksprs.experimental.distribution_mdi import gather_mdi
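
Version 1.5 introduces an `experimental` namespace alongside the existing `misc` and `util` class-based namespaces; it currently re-exports only the new multi-dimensional-index gather. A minimal access sketch (assuming the wheel is installed):

    import blksprs

    # The class attribute forwards to the module added in this release
    gather_mdi = blksprs.experimental.gather_mdi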
blksprs/experimental/distribution_mdi.py ADDED
@@ -0,0 +1,438 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
+                idx_bat: Tensor,
+                idx_row: Tensor,
+                idx_col: Tensor,
+                sparsity_layout_idx: Tensor,
+                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     src = src.contiguous()
+     idx_bat = idx_bat.contiguous()
+     idx_col = idx_col.contiguous()
+
+     validate_dimensions(src, idx_bat, idx_col)
+     validate_contiguous(src, idx_bat, idx_col)
+     validate_dtype_int(idx_bat, idx_col)
+     validate_device(src, idx_bat, idx_col)
+     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                       (idx_bat, sparsity_layout_idx), (idx_col, sparsity_layout_idx))
+     validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                               (sparsity_layout_x_flat == 1) -
+                               (1 * (sparsity_layout_x_flat == 0)))
+
+     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
+                         sparsity_layout_idx, sparsity_lut_i)
+
+     return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                        idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
+                                        sparsity_block_size, triton_block_size)
+
+
+ class _BlocksparseGatherMDI(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                 idx_bat: Tensor, idx_col: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+         output = torch.empty_like(idx_col, dtype=x.dtype)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+         i_b, i_r, i_c = idx_col.size()
+         i_b_s, i_r_s, i_c_s = idx_col.stride()
+         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+         s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseGatherMDI.kernel_blocksparse_gather_mdi[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_reverse_lut_x,
+           idx_bat,
+           idx_col,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         return scatter_reduce_mdi(grad_output, sparsity_layout_i,
+                                   idx_bat,
+                                   None,
+                                   idx_col,
+                                   sparsity_layout_x,
+                                   sparsity_block_size,
+                                   reduce_op="sum",
+                                   triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_gather_mdi(x,
+                                       x_b, x_b_s, x_r_s, x_c_s,
+                                       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                       r_lut_x,
+                                       idx_bat,
+                                       idx_col,
+                                       i_b, i_b_s, i_r_s, i_c_s,
+                                       o,
+                                       o_b, o_b_s, o_r_s, o_c_s,
+                                       s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                       sparsity_block_size,
+                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Load batch index values
+         blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+         blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+         # Get position of current sparsity block row
+         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+         # Load column index values
+         blk_idx_col_idx = ((pid_blk * i_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+         blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+         # Get positions of sparsity blocks
+         pos_spa_blk_x = blk_idx_col // sparsity_block_size
+         pos_spa_col_x = blk_idx_col % sparsity_block_size
+
+         # Load reverse sparsity indices for x
+         rev_idx_spa_x_idx = ((blk_idx_bat * s_l_x_b_s) +
+                              (spa_row_o * s_l_x_r_s) +
+                              (pos_spa_blk_x * s_l_x_c_s))
+         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+         # Load x values
+         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      (pos_spa_col_x * x_c_s))
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Store output
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
+                        idx_bat: Tensor,
+                        idx_row: Tensor,
+                        idx_col: Tensor,
+                        sparsity_layout_tgt: Tensor,
+                        sparsity_block_size: int,
+                        reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+     src = src.contiguous()
+     idx_bat = idx_bat.contiguous()
+     idx_col = idx_col.contiguous()
+
+     validate_dimensions(src, idx_bat, idx_col)
+     validate_contiguous(src, idx_bat, idx_col)
+     validate_dtype_int(idx_bat, idx_col)
+     validate_device(src, idx_bat, idx_col)
+     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                       (idx_bat, sparsity_layout_src),
+                       (idx_col, sparsity_layout_src))
+     validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     if reduce_op not in ["none", "sum"]:
+         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
+
+     sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+
+     sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+     sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                               (sparsity_layout_o_flat == 1) -
+                               (1 * (sparsity_layout_o_flat == 0)))
+
+     n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
+                         sparsity_layout_tgt, sparsity_reverse_lut_o)
+
+     return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                               idx_bat,
+                                               idx_col,
+                                               sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                               sparsity_block_size, n_sparse_blocks,
+                                               reduce_op, triton_block_size)
+
+
+ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
+                 idx_bat: Tensor,
+                 idx_col: Tensor,
+                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int,
+                 reduce_op: str, triton_block_size: int) -> Tensor:
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+         s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+         i_b, i_r, i_c = idx_col.size()
+         i_b_s, i_r_s, i_c_s = idx_col.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         reduce_op_ind = 0
+         if reduce_op == "sum":
+             reduce_op_ind = 1
+
+         (_BlocksparseScatterReduceMDI.kernel_blocksparse_scatter_mdi[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           idx_bat,
+           idx_col,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+           sparsity_reverse_lut_o,
+           reduce_op_ind,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.reduce_op = reduce_op
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         reduce_op = ctx.reduce_op
+         triton_block_size = ctx.triton_block_size
+
+         if reduce_op == "sum":
+             return gather_mdi(grad_output, sparsity_layout_o,
+                               idx_bat,
+                               None,
+                               idx_col,
+                               sparsity_layout_x, sparsity_block_size,
+                               triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
+         else:
+             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_scatter_mdi(x,
+                                        x_b, x_b_s, x_r_s, x_c_s,
+                                        s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                        idx_bat,
+                                        idx_col,
+                                        i_b, i_b_s, i_r_s, i_c_s,
+                                        o,
+                                        o_b, o_b_s, o_r_s, o_c_s,
+                                        s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                        r_lut_o,
+                                        reduce_op_ind,
+                                        sparsity_block_size,
+                                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Load batch index values
+         blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+         blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+         # Get position of current sparsity block row
+         spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+         spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+         spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+         # Load x values
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Load column index values
+         blk_idx_col_idx = ((pid_blk * i_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+         blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+         # Get positions of sparsity blocks
+         pos_spa_blk_o = blk_idx_col // sparsity_block_size
+         pos_spa_col_o = blk_idx_col % sparsity_block_size
+
+         # Load reverse sparsity indices for o
+         rev_idx_spa_o_idx = ((blk_idx_bat * s_l_o_b_s) +
+                              (spa_row_x * s_l_o_r_s) +
+                              (pos_spa_blk_o * s_l_o_c_s))
+         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+         # Store output
+         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      (pos_spa_col_o * o_c_s))
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+
+         if reduce_op_ind == 0:
+             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+         elif reduce_op_ind == 1:
+             tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
+                                   size_target: torch.Size,
+                                   sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     validate_dimensions(idx_bat, idx_col)
+     validate_contiguous(idx_bat, idx_col)
+     validate_device(idx_bat, idx_col)
+
+     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
+                          dtype=torch.bool, device=idx_col.device)
+
+     i_b, i_r, i_c = idx_col.size()
+     i_b_s, i_r_s, i_c_s = idx_col.stride()
+     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+     s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     triton_grid = lambda meta: [i_b,
+                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_distribution_layout_mdi[triton_grid]
+      (idx_bat,
+       idx_col,
+       i_b, i_b_s, i_r_s, i_c_s,
+       sparsity_lut_i,
+       s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_block_size,
+       triton_block_size))
+
+     return output
+
+
+ @triton.jit
+ def kernel_distribution_layout_mdi(idx_bat,
+                                    idx_col,
+                                    i_b, i_b_s, i_r_s, i_c_s,
+                                    s_lut_i,
+                                    s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s, o_c_s,
+                                    sparsity_block_size,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Load batch index values
+     blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                        ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                        ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+     blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+     blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+     # Get position of current sparsity block row
+     spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
+     spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
+
+     blk_i_idx = (pid_blk * i_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+     blk_i_msk = (blk_i_idx < i_b * i_b_s)
+     blk_i = tl.load(idx_col + blk_i_idx, mask=blk_i_msk)
+
+     blk_i = blk_i // sparsity_block_size
+     blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
+
+     blk_o_idx = ((blk_idx_bat * o_b_s) +
+                  (spa_row_i * o_r_s) +
+                  (blk_i * o_c_s))
+     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
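
gather_mdi and scatter_reduce_mdi above are adjoint operations: the backward pass of each is implemented by calling the other. Both rely on the reverse sparsity LUT built with torch.cumsum, which maps a position in the dense block layout to the index of the corresponding block in the compacted sparse tensor, with -1 marking absent blocks. A minimal sketch of that construction outside the kernels (the layout values are made up for illustration):

    import torch

    # 1 where a block is materialized, 0 where it is absent (batch of one, 2x2 blocks)
    sparsity_layout = torch.tensor([[[1, 0],
                                     [0, 1]]])

    flat = sparsity_layout.reshape(-1)
    # Running count of materialized blocks, shifted to start at 0; positions
    # without a block are forced to -1
    reverse_lut = (torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - (1 * (flat == 0))
    print(reverse_lut)  # tensor([ 0, -1, -1,  1])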
blksprs/layouting/distribution_layout.py CHANGED
@@ -35,8 +35,6 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
  
      i_b, i_r, i_c = indices.size()
      i_b_s, i_r_s, i_c_s = indices.stride()
-     s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
-     s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
      s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
      s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
      o_b, o_r, o_c = output.size()
@@ -54,12 +52,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
      (kernel_distribution_layout[triton_grid]
       (indices,
        i_b, i_b_s, i_r_s, i_c_s,
-       sparsity_layout_indices,
-       s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
        sparsity_lut_i,
-       s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+       s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
        output,
-       o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+       o_b, o_b_s, o_r_s, o_c_s,
        sparsity_block_size,
        triton_block_size))
  
@@ -69,12 +65,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
  @triton.jit
  def kernel_distribution_layout(i,
                                 i_b, i_b_s, i_r_s, i_c_s,
-                                s_l_i,
-                                s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                                 s_lut_i,
-                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
                                 o,
-                                o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+                                o_b, o_b_s, o_r_s, o_c_s,
                                 sparsity_block_size,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
@@ -105,10 +99,3 @@ def kernel_distribution_layout(i,
                   (blk_i * o_c_s))
      blk_o_msk = (blk_o_idx < o_b * o_b_s)
      tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
- 
-     # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
-     #     blk_o_idx = (pid_bat * o_b_s +
-     #                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
-     #                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-     #     blk_o_msk = (blk_o_idx < o_b * o_b_s)
-     #     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
blksprs/misc/broadcast_ops.py CHANGED
@@ -41,7 +41,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
  
      validate_contiguous(sparsity_layout_output, sparsity_lut_o)
  
-     output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)
+     output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
  
      x_b, x_c = x.size()
      x_b_s, x_c_s = x.stride()
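
This allocation fix, repeated below in row_wise.py, conversion.py, matmul.py, and transpose.py, makes freshly allocated outputs inherit the input dtype instead of PyTorch's float32 default, which previously produced float32 outputs (or kernel dtype mismatches) for e.g. float16 inputs. A small illustration of the difference (shapes are arbitrary):

    import torch

    x = torch.randn(4, 8, 8, dtype=torch.float16)
    out_old = torch.zeros(4, 8, 8, device=x.device)                 # float32, ignores x.dtype
    out_new = torch.zeros(4, 8, 8, dtype=x.dtype, device=x.device)  # float16, matches x
    assert out_new.dtype == x.dtype != out_old.dtype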
blksprs/misc/row_wise.py CHANGED
@@ -56,6 +56,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
      output = torch.zeros(size=(n_sparse_blocks_output,
                                 sparsity_block_size,
                                 1 if flag_slice_only else sparsity_block_size),
+                          dtype=x.dtype,
                           device=x.device)
  
      x_b, x_r, x_c = x.size()
blksprs/ops/conversion.py CHANGED
@@ -186,8 +186,8 @@ class _BlocksparseToSparse(torch.autograd.Function):
      def forward(ctx, x: Tensor,
                  sparsity_layout: Tensor, sparsity_lut: Tensor,
                  sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), dtype=x.dtype,
-                              device=x.device)
+         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
  
          x_b, x_r, x_c = x.size()
          x_b_s, x_r_s, x_c_s = x.stride()
blksprs/ops/matmul.py CHANGED
@@ -78,7 +78,8 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
                  sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                  sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
                  sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
+         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
  
          x_b, x_r, x_c = x.size()
          x_b_s, x_r_s, x_c_s = x.stride()
blksprs/ops/softmax.py CHANGED
@@ -127,7 +127,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
          s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
          s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
  
-         grad_x = torch.empty_like(o)
+         grad_x = torch.empty_like(o, dtype=torch.float)
  
          triton_grid = lambda meta: [o_b,
                                      triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
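
Unlike the other allocation fixes in this release, the softmax backward buffer is pinned to torch.float rather than inheriting the dtype of o, presumably to keep the gradient computation in full precision even when the forward pass ran in half precision (a reading inferred from the change, not stated in the diff):

    import torch

    o = torch.randn(2, 4, 4, dtype=torch.float16)
    grad_x = torch.empty_like(o, dtype=torch.float)  # float32 buffer for a float16 output
    assert grad_x.dtype == torch.float32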
blksprs/ops/transpose.py CHANGED
@@ -59,7 +59,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
      def forward(ctx, x: Tensor,
                  sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
                  n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
-         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
+         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
  
          x_b, x_r, x_c = x.size()
          x_b_s, x_r_s, x_c_s = x.stride()
@@ -101,7 +102,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
          sparsity_block_size = ctx.sparsity_block_size
          triton_block_size = ctx.triton_block_size
  
-         return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None
+         return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
+                    0], None, None, None, None, None, None
  
      @staticmethod
      @triton.jit
blksprs/utils/tools.py CHANGED
@@ -1,4 +1,3 @@
- import torch
  from torch import Tensor, Size
  
  from blksprs.utils.validation import _set_skip_validation
@@ -8,7 +7,7 @@ def do_shape_blocksparse(x: Tensor):
      if x.dim() == 3:
          return x.contiguous(), x.size()
  
-     return x.reshape(-1, x.size(-2), x.size(-1)), x.size()
+     return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
  
  
  def undo_shape_blocksparse(x: Tensor, shape: Size):
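
The added .contiguous() guards the flattened view handed to the Triton kernels: Tensor.reshape returns a view when strides allow, and such a view of a non-contiguous input can itself be non-contiguous. A contrived sketch of a case the old code would have let through:

    import torch

    x = torch.randn(2, 3, 4, 6)[..., ::2]      # shape (2, 3, 4, 3), non-contiguous
    y = x.reshape(-1, x.size(-2), x.size(-1))  # leading dims merge as a view here
    print(y.is_contiguous())                   # False without the added .contiguous()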
blksprs/utils/validation.py CHANGED
@@ -3,13 +3,13 @@ from torch import Tensor
  
  VALIDATION = True
  
- def validate_dimensions(*tensors: Tensor) -> None:
+ def validate_dimensions(*tensors: Tensor, dims=3) -> None:
      if _check_skip_validation():
          return
  
      for tensor in tensors:
-         if tensor.dim() != 3:
-             raise ValueError("Tensor must have 3 dimensions")
+         if tensor.dim() != dims:
+             raise ValueError(f"Tensor must have {dims} dimensions")
  
  
  def validate_contiguous(*tensors: Tensor) -> None:
@@ -91,6 +91,9 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
      if triton_block_size is None:
          return
  
+     if not (triton_block_size & (triton_block_size - 1)) == 0:
+         raise ValueError("Triton block size must be a power of 2")
+
      if triton_block_size > sparsity_block_size:
          raise ValueError("Triton block size cannot be larger than sparsity block size")
  
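
The new power-of-two check uses the standard bit trick: for a positive integer n, n & (n - 1) clears the lowest set bit, so the result is zero exactly when n has a single bit set. A quick sanity check of the predicate as written (note that it also lets 0 through, since 0 & -1 == 0):

    def is_power_of_two(n: int) -> bool:
        # Mirrors the added validation: clear the lowest set bit, require nothing left
        return (n & (n - 1)) == 0

    print([n for n in range(1, 17) if is_power_of_two(n)])  # [1, 2, 4, 8, 16]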
blksprs-1.4.1.dist-info/METADATA → blksprs-1.5.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.4.1
+ Version: 1.5
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
blksprs-1.5.dist-info/RECORD ADDED
@@ -0,0 +1,20 @@
+ blksprs/__init__.py,sha256=OY9ofdbzBGsvY6hx0oLCrSszlJFdMns9x7gKE0asFI0,919
+ blksprs/experimental/distribution_mdi.py,sha256=shu-3Nt7nkaLIb4O2kSajC8Lh7IWFXO9rsjzP14ASYA,20088
+ blksprs/layouting/distribution_layout.py,sha256=Zv-b2t5VOvW6-ejdX42kUV7X1yYsvDCY_PXFE_wKwi0,4165
+ blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
+ blksprs/misc/broadcast_ops.py,sha256=ahm7_lI12bJ6VTKRuSkwEeaEYWRY-BeMIOhtei35zpQ,5323
+ blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
+ blksprs/misc/row_wise.py,sha256=1UtjLplrGx1FkxhzQ2hjSBBY11ToLQs0JiLaXKRAkL4,16893
+ blksprs/ops/conversion.py,sha256=vuiNwrwyuGI6H4PKrS_UHI7OKWJwNZd2i3LSjf6RetU,21332
+ blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
+ blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
+ blksprs/ops/matmul.py,sha256=743XeD5M4iUv28sYf7q6mVXDd4jZpV04JAx8bF7hWkw,11254
+ blksprs/ops/softmax.py,sha256=cs1utM6UCzHhdJpf-ZysBr6CwbjI-5aQG0ahYY37Zy0,11991
+ blksprs/ops/transpose.py,sha256=Ru4YKyg796WT6OnDSTCYG45tMmdgvju3hMFzkwsJnO8,6801
+ blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+ blksprs/utils/tools.py,sha256=JAuwsLISr_hcvxIgUVvKz5ZPf9M5ycquplsBU5dVfDc,596
+ blksprs/utils/validation.py,sha256=rP6yr-C2ghXfJEERry_pfvVJ0g0VyqV4sL4HkBRlJg8,3345
+ blksprs-1.5.dist-info/METADATA,sha256=dql0_6s1Vfdnx6sLFusayZWSeU9uxvfAjBDdLPk43so,7607
+ blksprs-1.5.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+ blksprs-1.5.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+ blksprs-1.5.dist-info/RECORD,,
blksprs-1.4.1.dist-info/WHEEL → blksprs-1.5.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.1.0)
+ Generator: setuptools (75.2.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
  
blksprs-1.4.1.dist-info/RECORD DELETED
@@ -1,19 +0,0 @@
- blksprs/__init__.py,sha256=ORAVhGR91G1wyIOs9Wecv-xfmjju3bJ4Jynq_SGOVY4,833
- blksprs/layouting/distribution_layout.py,sha256=Xd8KjZwI87L9EL1Bw5SGUW9YztFD5q0Ygr99sffvdak,4939
- blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
- blksprs/misc/broadcast_ops.py,sha256=RTcqvx6X_THRBb55jipeEe63YSLIAh27jdpuze0aSek,5308
- blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
- blksprs/misc/row_wise.py,sha256=KCDO5ry5TkjI88LLD_QINZwBkzfmjoQpOOvYLfpUn5I,16853
- blksprs/ops/conversion.py,sha256=h1c5T74rQjqYgY9dwWXfPTXRpgzy0dtAhCmtUp8-6uo,21332
- blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
- blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
- blksprs/ops/matmul.py,sha256=6DaYxecJgwiW8L-UISkgyNyzQ31AAkmDL-Oq1EjHt98,11210
- blksprs/ops/softmax.py,sha256=cSTxDnNmMRlJGOlCSpdg1U5KUIFpVtHulz8fteJFeh0,11972
- blksprs/ops/transpose.py,sha256=et8R124L29TUqihci18ms_hBoYXTtPu5LXgEA8sxk_w,6744
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
- blksprs/utils/tools.py,sha256=RKGWCGd5h1qFOIoShsdJObx4-QsS0RxCyzFie0geNxo,596
- blksprs/utils/validation.py,sha256=Gsx3aah6355bWXRPpbFuZ1p0fOrYduIqaM3ON9d5NiI,3197
- blksprs-1.4.1.dist-info/METADATA,sha256=3xRmBFHv2U2KnrW3_QX3003SHLkQ1JCaSqh4AUBsJD4,7609
- blksprs-1.4.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
- blksprs-1.4.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-1.4.1.dist-info/RECORD,,