blksprs 1.4.2__tar.gz → 1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {blksprs-1.4.2 → blksprs-1.5}/PKG-INFO +1 -1
  2. {blksprs-1.4.2 → blksprs-1.5}/blksprs/__init__.py +4 -1
  3. blksprs-1.5/blksprs/experimental/distribution_mdi.py +438 -0
  4. {blksprs-1.4.2 → blksprs-1.5}/blksprs/layouting/distribution_layout.py +4 -17
  5. {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/PKG-INFO +1 -1
  6. {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/SOURCES.txt +1 -0
  7. {blksprs-1.4.2 → blksprs-1.5}/pyproject.toml +1 -1
  8. {blksprs-1.4.2 → blksprs-1.5}/README.md +0 -0
  9. {blksprs-1.4.2 → blksprs-1.5}/blksprs/layouting/sparsity_layout.py +0 -0
  10. {blksprs-1.4.2 → blksprs-1.5}/blksprs/misc/broadcast_ops.py +0 -0
  11. {blksprs-1.4.2 → blksprs-1.5}/blksprs/misc/repeat_interleave.py +0 -0
  12. {blksprs-1.4.2 → blksprs-1.5}/blksprs/misc/row_wise.py +0 -0
  13. {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/conversion.py +0 -0
  14. {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/distribution.py +0 -0
  15. {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/exp.py +0 -0
  16. {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/matmul.py +0 -0
  17. {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/softmax.py +0 -0
  18. {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/transpose.py +0 -0
  19. {blksprs-1.4.2 → blksprs-1.5}/blksprs/utils/benchmarking.py +0 -0
  20. {blksprs-1.4.2 → blksprs-1.5}/blksprs/utils/tools.py +0 -0
  21. {blksprs-1.4.2 → blksprs-1.5}/blksprs/utils/validation.py +0 -0
  22. {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/dependency_links.txt +0 -0
  23. {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/requires.txt +0 -0
  24. {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/top_level.txt +0 -0
  25. {blksprs-1.4.2 → blksprs-1.5}/setup.cfg +0 -0
{blksprs-1.4.2 → blksprs-1.5}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.4.2
+ Version: 1.5
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
{blksprs-1.4.2 → blksprs-1.5}/blksprs/__init__.py
@@ -15,4 +15,7 @@ class misc:
      from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub

  class util:
-     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+
+ class experimental:
+     from blksprs.experimental.distribution_mdi import gather_mdi
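The new experimental namespace follows the same class-as-namespace pattern as misc and util. A minimal usage sketch, assuming only the layout shown in this hunk (not taken from the package documentation):

```python
# Sketch: assumes the namespace classes defined in blksprs/__init__.py above.
import blksprs as bs

# Existing utilities stay where they were ...
do_shape = bs.util.do_shape_blocksparse

# ... and the multi-dimensional-index gather added in 1.5 lives under bs.experimental.
gather_mdi = bs.experimental.gather_mdi
```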
blksprs-1.5/blksprs/experimental/distribution_mdi.py (new file)
@@ -0,0 +1,438 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
+                idx_bat: Tensor,
+                idx_row: Tensor,
+                idx_col: Tensor,
+                sparsity_layout_idx: Tensor,
+                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     src = src.contiguous()
+     idx_bat = idx_bat.contiguous()
+     idx_col = idx_col.contiguous()
+
+     validate_dimensions(src, idx_bat, idx_col)
+     validate_contiguous(src, idx_bat, idx_col)
+     validate_dtype_int(idx_bat, idx_col)
+     validate_device(src, idx_bat, idx_col)
+     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                       (idx_bat, sparsity_layout_idx), (idx_col, sparsity_layout_idx))
+     validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                               (sparsity_layout_x_flat == 1) -
+                               (1 * (sparsity_layout_x_flat == 0)))
+
+     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
+                         sparsity_layout_idx, sparsity_lut_i)
+
+     return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                        idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
+                                        sparsity_block_size, triton_block_size)
+
+
+ class _BlocksparseGatherMDI(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                 idx_bat: Tensor, idx_col: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+         output = torch.empty_like(idx_col, dtype=x.dtype)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+         i_b, i_r, i_c = idx_col.size()
+         i_b_s, i_r_s, i_c_s = idx_col.stride()
+         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+         s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseGatherMDI.kernel_blocksparse_gather_mdi[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_reverse_lut_x,
+           idx_bat,
+           idx_col,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         return scatter_reduce_mdi(grad_output, sparsity_layout_i,
+                                   idx_bat,
+                                   None,
+                                   idx_col,
+                                   sparsity_layout_x,
+                                   sparsity_block_size,
+                                   reduce_op="sum",
+                                   triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_gather_mdi(x,
+                                       x_b, x_b_s, x_r_s, x_c_s,
+                                       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                       r_lut_x,
+                                       idx_bat,
+                                       idx_col,
+                                       i_b, i_b_s, i_r_s, i_c_s,
+                                       o,
+                                       o_b, o_b_s, o_r_s, o_c_s,
+                                       s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                       sparsity_block_size,
+                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Load batch index values
+         blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+         blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+         # Get position of current sparsity block row
+         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+         # Load column index values
+         blk_idx_col_idx = ((pid_blk * i_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+         blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+         # Get positions of sparsity blocks
+         pos_spa_blk_x = blk_idx_col // sparsity_block_size
+         pos_spa_col_x = blk_idx_col % sparsity_block_size
+
+         # Load reverse sparsity indices for x
+         rev_idx_spa_x_idx = ((blk_idx_bat * s_l_x_b_s) +
+                              (spa_row_o * s_l_x_r_s) +
+                              (pos_spa_blk_x * s_l_x_c_s))
+         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+         # Load x values
+         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      (pos_spa_col_x * x_c_s))
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Store output
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
+                        idx_bat: Tensor,
+                        idx_row: Tensor,
+                        idx_col: Tensor,
+                        sparsity_layout_tgt: Tensor,
+                        sparsity_block_size: int,
+                        reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+     src = src.contiguous()
+     idx_bat = idx_bat.contiguous()
+     idx_col = idx_col.contiguous()
+
+     validate_dimensions(src, idx_bat, idx_col)
+     validate_contiguous(src, idx_bat, idx_col)
+     validate_dtype_int(idx_bat, idx_col)
+     validate_device(src, idx_bat, idx_col)
+     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                       (idx_bat, sparsity_layout_src),
+                       (idx_col, sparsity_layout_src))
+     validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     if reduce_op not in ["none", "sum"]:
+         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
+
+     sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+
+     sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+     sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                               (sparsity_layout_o_flat == 1) -
+                               (1 * (sparsity_layout_o_flat == 0)))
+
+     n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
+                         sparsity_layout_tgt, sparsity_reverse_lut_o)
+
+     return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                               idx_bat,
+                                               idx_col,
+                                               sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                               sparsity_block_size, n_sparse_blocks,
+                                               reduce_op, triton_block_size)
+
+
+ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
+                 idx_bat: Tensor,
+                 idx_col: Tensor,
+                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int,
+                 reduce_op: str, triton_block_size: int) -> Tensor:
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+         s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+         i_b, i_r, i_c = idx_col.size()
+         i_b_s, i_r_s, i_c_s = idx_col.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         reduce_op_ind = 0
+         if reduce_op == "sum":
+             reduce_op_ind = 1
+
+         (_BlocksparseScatterReduceMDI.kernel_blocksparse_scatter_mdi[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           idx_bat,
+           idx_col,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+           sparsity_reverse_lut_o,
+           reduce_op_ind,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.reduce_op = reduce_op
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         reduce_op = ctx.reduce_op
+         triton_block_size = ctx.triton_block_size
+
+         if reduce_op == "sum":
+             return gather_mdi(grad_output, sparsity_layout_o,
+                               idx_bat,
+                               None,
+                               idx_col,
+                               sparsity_layout_x, sparsity_block_size,
+                               triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
+         else:
+             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_scatter_mdi(x,
+                                        x_b, x_b_s, x_r_s, x_c_s,
+                                        s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                        idx_bat,
+                                        idx_col,
+                                        i_b, i_b_s, i_r_s, i_c_s,
+                                        o,
+                                        o_b, o_b_s, o_r_s, o_c_s,
+                                        s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                        r_lut_o,
+                                        reduce_op_ind,
+                                        sparsity_block_size,
+                                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Load batch index values
+         blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+         blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+         # Get position of current sparsity block row
+         spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+         spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+         spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+         # Load x values
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Load column index values
+         blk_idx_col_idx = ((pid_blk * i_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+         blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+         # Get positions of sparsity blocks
+         pos_spa_blk_o = blk_idx_col // sparsity_block_size
+         pos_spa_col_o = blk_idx_col % sparsity_block_size
+
+         # Load reverse sparsity indices for o
+         rev_idx_spa_o_idx = ((blk_idx_bat * s_l_o_b_s) +
+                              (spa_row_x * s_l_o_r_s) +
+                              (pos_spa_blk_o * s_l_o_c_s))
+         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+         # Store output
+         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      (pos_spa_col_o * o_c_s))
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+
+         if reduce_op_ind == 0:
+             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+         elif reduce_op_ind == 1:
+             tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
+                                   size_target: torch.Size,
+                                   sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     validate_dimensions(idx_bat, idx_col)
+     validate_contiguous(idx_bat, idx_col)
+     validate_device(idx_bat, idx_col)
+
+     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
+                          dtype=torch.bool, device=idx_col.device)
+
+     i_b, i_r, i_c = idx_col.size()
+     i_b_s, i_r_s, i_c_s = idx_col.stride()
+     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+     s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     triton_grid = lambda meta: [i_b,
+                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_distribution_layout_mdi[triton_grid]
+      (idx_bat,
+       idx_col,
+       i_b, i_b_s, i_r_s, i_c_s,
+       sparsity_lut_i,
+       s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_block_size,
+       triton_block_size))
+
+     return output
+
+
+ @triton.jit
+ def kernel_distribution_layout_mdi(idx_bat,
+                                    idx_col,
+                                    i_b, i_b_s, i_r_s, i_c_s,
+                                    s_lut_i,
+                                    s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s, o_c_s,
+                                    sparsity_block_size,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Load batch index values
+     blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                        ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                        ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+     blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+     blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+     # Get position of current sparsity block row
+     spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
+     spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
+
+     blk_i_idx = (pid_blk * i_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+     blk_i_msk = (blk_i_idx < i_b * i_b_s)
+     blk_i = tl.load(idx_col + blk_i_idx, mask=blk_i_msk)
+
+     blk_i = blk_i // sparsity_block_size
+     blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
+
+     blk_o_idx = ((blk_idx_bat * o_b_s) +
+                  (spa_row_i * o_r_s) +
+                  (blk_i * o_c_s))
+     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
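The new module adds a gather driven by separate batch, row, and column index tensors (presumably "mdi" stands for multi-dimensional indices, given the idx_bat/idx_row/idx_col arguments), its autograd counterpart scatter_reduce_mdi, and a layout builder for the result. A hedged usage sketch of gather_mdi, based only on the validation and kernel code above; shapes, the CUDA device, and the dummy index values are assumptions, not package documentation:

```python
# Editorial sketch (not part of the package): call shape of gather_mdi as
# suggested by the code above. Block-compressed tensors are assumed to have
# shape (num_sparse_blocks, sparsity_block_size, sparsity_block_size).
import torch
from blksprs.experimental.distribution_mdi import gather_mdi

sparsity_block_size = 16
device = torch.device("cuda")

# Dense sparsity layouts (batch, row blocks, column blocks); fully populated here.
sparsity_layout_src = torch.ones(2, 4, 4, dtype=torch.int, device=device)
sparsity_layout_idx = torch.ones(2, 4, 4, dtype=torch.int, device=device)

n_src = int(sparsity_layout_src.sum())
n_idx = int(sparsity_layout_idx.sum())

# Block-compressed source values and integer index tensors (one block per
# non-zero layout entry); idx_bat selects the source batch and idx_col the
# flattened source column for each output element.
src = torch.randn(n_src, sparsity_block_size, sparsity_block_size, device=device)
idx_bat = torch.zeros(n_idx, sparsity_block_size, sparsity_block_size,
                      dtype=torch.int, device=device)
idx_col = torch.randint(0, 4 * sparsity_block_size,
                        (n_idx, sparsity_block_size, sparsity_block_size),
                        dtype=torch.int, device=device)

# idx_row is accepted by the signature but not used by the forward pass above,
# so None is passed here (mirroring the module's own backward calls).
out = gather_mdi(src, sparsity_layout_src,
                 idx_bat, None, idx_col,
                 sparsity_layout_idx, sparsity_block_size)
```

Note that the __init__.py hunk only re-exports gather_mdi; scatter_reduce_mdi and build_distribution_layout_mdi remain importable from the module itself.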
{blksprs-1.4.2 → blksprs-1.5}/blksprs/layouting/distribution_layout.py
@@ -35,8 +35,6 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,

      i_b, i_r, i_c = indices.size()
      i_b_s, i_r_s, i_c_s = indices.stride()
-     s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
-     s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
      s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
      s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
      o_b, o_r, o_c = output.size()
@@ -54,12 +52,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
      (kernel_distribution_layout[triton_grid]
       (indices,
        i_b, i_b_s, i_r_s, i_c_s,
-       sparsity_layout_indices,
-       s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
        sparsity_lut_i,
-       s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+       s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
        output,
-       o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+       o_b, o_b_s, o_r_s, o_c_s,
        sparsity_block_size,
        triton_block_size))

@@ -69,12 +65,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
  @triton.jit
  def kernel_distribution_layout(i,
                                 i_b, i_b_s, i_r_s, i_c_s,
-                                s_l_i,
-                                s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                                 s_lut_i,
-                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
                                 o,
-                                o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+                                o_b, o_b_s, o_r_s, o_c_s,
                                 sparsity_block_size,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
@@ -105,10 +99,3 @@ def kernel_distribution_layout(i,
                   (blk_i * o_c_s))
      blk_o_msk = (blk_o_idx < o_b * o_b_s)
      tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
-
-     # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
-     #     blk_o_idx = (pid_bat * o_b_s +
-     #                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
-     #                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-     #     blk_o_msk = (blk_o_idx < o_b * o_b_s)
-     #     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
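The four hunks above drop the arguments that kernel_distribution_layout never reads (the sparsity_layout_indices tensor with its sizes and strides, plus the redundant o_r, o_c, and s_lut_i_c sizes) and delete a block of commented-out code. Reconstructed from the + lines above, the slimmed-down launch in build_distribution_layout now reads:

```python
# Reconstructed from the diff above (no new behaviour): the updated launch of
# kernel_distribution_layout; all names are computed earlier in
# build_distribution_layout.
(kernel_distribution_layout[triton_grid]
 (indices,
  i_b, i_b_s, i_r_s, i_c_s,
  sparsity_lut_i,
  s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
  output,
  o_b, o_b_s, o_r_s, o_c_s,
  sparsity_block_size,
  triton_block_size))
```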
{blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.4.2
+ Version: 1.5
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
{blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/SOURCES.txt
@@ -6,6 +6,7 @@ blksprs.egg-info/SOURCES.txt
  blksprs.egg-info/dependency_links.txt
  blksprs.egg-info/requires.txt
  blksprs.egg-info/top_level.txt
+ blksprs/experimental/distribution_mdi.py
  blksprs/layouting/distribution_layout.py
  blksprs/layouting/sparsity_layout.py
  blksprs/misc/broadcast_ops.py
{blksprs-1.4.2 → blksprs-1.5}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "blksprs"
- version = "1.4.2"
+ version = "1.5"
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
  description = "A lightweight library for operations on blocksparse matrices in PyTorch."
  readme = "README.md"