blksprs 1.9.2__py3-none-any.whl → 1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/experimental/distribution_mdi.py
@@ -1,447 +0,0 @@
- import torch
- import triton
- from torch import Tensor
- from triton import language as tl
-
- from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
-
-
- def gather_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
-                idx_bat: BlksprsTensor,
-                idx_row: BlksprsTensor,
-                idx_col: BlksprsTensor,
-                sparsity_layout_idx: Tensor,
-                sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
-     src = src.contiguous()
-     idx_bat = idx_bat.contiguous()
-     idx_col = idx_col.contiguous()
-
-     validate_dimensions(src, idx_bat, idx_col)
-     validate_contiguous(src, idx_bat, idx_col)
-     validate_dtype_int(idx_bat, idx_col)
-     validate_device(src, idx_bat, idx_col)
-     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
-                       (idx_bat, sparsity_layout_idx), (idx_col, sparsity_layout_idx))
-     validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                               (sparsity_layout_x_flat == 1) -
-                               (1 * (sparsity_layout_x_flat == 0)))
-
-     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-
-     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
-                         sparsity_layout_idx, sparsity_lut_i)
-
-     return BlksprsTensor(_BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-                                                      idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
-                                                      sparsity_block_size, triton_block_size))
-
-
- class _BlocksparseGatherMDI(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                 idx_bat: Tensor, idx_col: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
-                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-         output = torch.empty_like(idx_col, dtype=x.dtype)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
-         i_b, i_r, i_c = idx_col.size()
-         i_b_s, i_r_s, i_c_s = stride(idx_col)
-         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-         s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseGatherMDI.kernel_blocksparse_gather_mdi[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-           sparsity_reverse_lut_x,
-           idx_bat,
-           idx_col,
-           i_b, i_b_s, i_r_s, i_c_s,
-           output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i = ctx.saved_tensors
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         return scatter_reduce_mdi(grad_output, sparsity_layout_i,
-                                   idx_bat,
-                                   None,
-                                   idx_col,
-                                   sparsity_layout_x,
-                                   sparsity_block_size,
-                                   reduce_op="sum",
-                                   triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_gather_mdi(x,
-                                       x_b, x_b_s, x_r_s, x_c_s,
-                                       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                                       r_lut_x,
-                                       idx_bat,
-                                       idx_col,
-                                       i_b, i_b_s, i_r_s, i_c_s,
-                                       o,
-                                       o_b, o_b_s, o_r_s, o_c_s,
-                                       s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                                       sparsity_block_size,
-                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Load batch index values
-         blk_idx_bat_idx = ((pid_blk * i_b_s) +
-                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-         blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
-         blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
-
-         # Get position of current sparsity block row
-         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-         # Load column index values
-         blk_idx_col_idx = ((pid_blk * i_b_s) +
-                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-         blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
-         blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
-
-         # Get positions of sparsity blocks
-         pos_spa_blk_x = blk_idx_col // sparsity_block_size
-         pos_spa_col_x = blk_idx_col % sparsity_block_size
-
-         # Load reverse sparsity indices for x
-         rev_idx_spa_x_idx = ((blk_idx_bat * s_l_x_b_s) +
-                              (spa_row_o * s_l_x_r_s) +
-                              (pos_spa_blk_x * s_l_x_c_s))
-         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
-         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
-         if rev_idx_spa_x == -1:
-             tl.device_assert(False)
-             return
-
-         # Load x values
-         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      (pos_spa_col_x * x_c_s))
-         blk_x_msk = (blk_x_idx < x_b * x_b_s)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-         # Store output
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx < o_b * o_b_s)
-         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
- def scatter_reduce_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
-                        idx_bat: BlksprsTensor,
-                        idx_row: BlksprsTensor,
-                        idx_col: BlksprsTensor,
-                        sparsity_layout_tgt: Tensor,
-                        sparsity_block_size: int,
-                        reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
-     src = src.contiguous()
-     idx_bat = idx_bat.contiguous()
-     idx_col = idx_col.contiguous()
-
-     validate_dimensions(src, idx_bat, idx_col)
-     validate_contiguous(src, idx_bat, idx_col)
-     validate_dtype_int(idx_bat, idx_col)
-     validate_device(src, idx_bat, idx_col)
-     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
-                       (idx_bat, sparsity_layout_src),
-                       (idx_col, sparsity_layout_src))
-     validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     if reduce_op not in ["none", "sum"]:
-         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
-
-     sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
-
-     sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
-     sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
-                               (sparsity_layout_o_flat == 1) -
-                               (1 * (sparsity_layout_o_flat == 0)))
-
-     n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
-
-     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
-                         sparsity_layout_tgt, sparsity_reverse_lut_o)
-
-     return BlksprsTensor(_BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
-                                                             idx_bat,
-                                                             idx_col,
-                                                             sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                                             sparsity_block_size, n_sparse_blocks,
-                                                             reduce_op, triton_block_size))
-
-
- class _BlocksparseScatterReduceMDI(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                 idx_bat: Tensor,
-                 idx_col: Tensor,
-                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int,
-                 reduce_op: str, triton_block_size: int) -> Tensor:
-         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                              dtype=x.dtype, device=x.device)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
-         s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
-         i_b, i_r, i_c = idx_col.size()
-         i_b_s, i_r_s, i_c_s = stride(idx_col)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
-         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [x_b,
-                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-         reduce_op_ind = 0
-         if reduce_op == "sum":
-             reduce_op_ind = 1
-
-         (_BlocksparseScatterReduceMDI.kernel_blocksparse_scatter_mdi[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-           idx_bat,
-           idx_col,
-           i_b, i_b_s, i_r_s, i_c_s,
-           output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-           sparsity_reverse_lut_o,
-           reduce_op_ind,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.reduce_op = reduce_op
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o = ctx.saved_tensors
-         sparsity_block_size = ctx.sparsity_block_size
-         reduce_op = ctx.reduce_op
-         triton_block_size = ctx.triton_block_size
-
-         if reduce_op == "sum":
-             return gather_mdi(grad_output, sparsity_layout_o,
-                               idx_bat,
-                               None,
-                               idx_col,
-                               sparsity_layout_x, sparsity_block_size,
-                               triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
-         else:
-             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_scatter_mdi(x,
-                                        x_b, x_b_s, x_r_s, x_c_s,
-                                        s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                        idx_bat,
-                                        idx_col,
-                                        i_b, i_b_s, i_r_s, i_c_s,
-                                        o,
-                                        o_b, o_b_s, o_r_s, o_c_s,
-                                        s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                        r_lut_o,
-                                        reduce_op_ind,
-                                        sparsity_block_size,
-                                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Load batch index values
-         blk_idx_bat_idx = ((pid_blk * i_b_s) +
-                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-         blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
-         blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
-
-         # Get position of current sparsity block row
-         spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-         spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
-         spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
-
-         # Load x values
-         blk_x_idx = ((pid_blk * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx < x_b * x_b_s)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-         # Load column index values
-         blk_idx_col_idx = ((pid_blk * i_b_s) +
-                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                            ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-         blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
-         blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
-
-         # Get positions of sparsity blocks
-         pos_spa_blk_o = blk_idx_col // sparsity_block_size
-         pos_spa_col_o = blk_idx_col % sparsity_block_size
-
-         # Load reverse sparsity indices for o
-         rev_idx_spa_o_idx = ((blk_idx_bat * s_l_o_b_s) +
-                              (spa_row_x * s_l_o_r_s) +
-                              (pos_spa_blk_o * s_l_o_c_s))
-         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
-         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
-
-         if rev_idx_spa_o == -1:
-             tl.device_assert(False)
-             return
-
-         # Store output
-         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      (pos_spa_col_o * o_c_s))
-         blk_o_msk = (blk_o_idx < o_b * o_b_s)
-
-         if reduce_op_ind == 0:
-             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-         elif reduce_op_ind == 1:
-             tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
- def build_distribution_layout_mdi(idx_bat: BlksprsTensor, idx_row: BlksprsTensor, idx_col: BlksprsTensor,
-                                   sparsity_layout_idx: Tensor, size_target: torch.Size,
-                                   sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-     validate_dimensions(idx_bat, idx_col)
-     validate_contiguous(idx_bat, idx_col)
-     validate_device(idx_bat, idx_col)
-
-     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-
-     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                          dtype=torch.bool, device=idx_col.device)
-
-     i_b, i_r, i_c = idx_col.size()
-     i_b_s, i_r_s, i_c_s = stride(idx_col)
-     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-     s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     triton_grid = lambda meta: [i_b,
-                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (kernel_distribution_layout_mdi[triton_grid]
-      (idx_bat,
-       idx_col,
-       i_b, i_b_s, i_r_s, i_c_s,
-       sparsity_lut_i,
-       s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_block_size,
-       triton_block_size))
-
-     return output
-
-
- @triton.jit
- def kernel_distribution_layout_mdi(idx_bat,
-                                    idx_col,
-                                    i_b, i_b_s, i_r_s, i_c_s,
-                                    s_lut_i,
-                                    s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-                                    o,
-                                    o_b, o_b_s, o_r_s, o_c_s,
-                                    sparsity_block_size,
-                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-     # Get triton block indices
-     pid_blk = tl.program_id(axis=0)
-     pid_row = tl.program_id(axis=1)
-     pid_col = tl.program_id(axis=2)
-
-     # Load batch index values
-     blk_idx_bat_idx = ((pid_blk * i_b_s) +
-                        ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                        ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-     blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
-     blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
-
-     # Get position of current sparsity block row
-     spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
-     spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
-     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
-
-     blk_i_idx = (pid_blk * i_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-     blk_i_msk = (blk_i_idx < i_b * i_b_s)
-     blk_i = tl.load(idx_col + blk_i_idx, mask=blk_i_msk)
-
-     blk_i = blk_i // sparsity_block_size
-     blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
-
-     blk_o_idx = ((blk_idx_bat * o_b_s) +
-                  (spa_row_i * o_r_s) +
-                  (blk_i * o_c_s))
-     blk_o_msk = (blk_o_idx < o_b * o_b_s)
-     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
blksprs-1.9.2.dist-info/RECORD
@@ -1,25 +0,0 @@
- blksprs/__init__.py,sha256=L2wP3sFBjfcIOuI2WhQW1eUEYuKoZLKxSV9z0aQmknM,2001
- blksprs/layouting/distribution_layout.py,sha256=9f_Bx2YQF4LTH95C0S7OuB9eeOuh73NcE0Z7Wrtug38,5034
- blksprs/layouting/sparsity_layout.py,sha256=-sScIn4hhG35j9BXytrojEzp8jnFkMargJjtivPV1fc,9755
- blksprs/ops/conversion.py,sha256=2lQZfPd1iFheXIcoH0LbN2m7vqFRQ8XUzhGFlDckBsM,22052
- blksprs/ops/distribution.py,sha256=JGa-eLY-1OgicU3vPAwuhqsoUIeyadzmTk2t25aYyak,19956
- blksprs/ops/flow.py,sha256=RBXNOA6O0Ay2sotH8uNoltZywkdxJocJCn3bfB1fGjM,6185
- blksprs/ops/matmul.py,sha256=yh2ZnO0ZltT1AgadiFP0vX28YJ4n74xO-I_5vFUmOmA,11452
- blksprs/ops/partitioning.py,sha256=z7kx4FrC-ugxZP-IsOHCfdbsF__ld0P-vDota5CbU4s,7672
- blksprs/ops/repeat.py,sha256=RCa-dITomA5v12K5Oxa5_ReA361zS7WHPNNHxSp9PGw,8578
- blksprs/ops/softmax.py,sha256=V-1vqRefjjwSp6JPwKxVxh5pTng9gOdtgGlXHDPbpYM,12190
- blksprs/ops/transpose.py,sha256=jxzFFffrj4S_9tiCrwwUMdz6EA98o1dziWXjlqb64a4,6859
- blksprs/ops/experimental/distribution_mdi.py,sha256=F_0tl4Gn-9JZs_TZfDtZqO_RPFl7sejqQNF8UNIoCbs,20533
- blksprs/ops/misc/broadcast_ops.py,sha256=cPtRJa3pkZfY1QG51CJ-zDn4SK-CRpX5LEXoKGGMvRU,5418
- blksprs/ops/misc/exp.py,sha256=FnSFosBfJHuiEbD0MD-i4axLghRn4a0f8KvHXrKBB6M,3802
- blksprs/ops/misc/row_wise.py,sha256=U4Kk0-P4oOuMNjMHXxP2gP9njMIeMfz8RZrzItNIF94,17229
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
- blksprs/utils/blksprs_tensor.py,sha256=VjplBgDhnf9sxf-1R5feA0xp5FDCDdaeZmCeoIRdCnc,151
- blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
- blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
- blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
- blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
- blksprs-1.9.2.dist-info/METADATA,sha256=JIHA58YnLfFrUyAOsPmHMWbDz_XmkDiXypLhg1ijO0E,8670
- blksprs-1.9.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- blksprs-1.9.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-1.9.2.dist-info/RECORD,,
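
For context on the removal above: distribution_mdi.py exposed the experimental multi-dimensional-index (MDI) distribution ops gather_mdi, scatter_reduce_mdi, and build_distribution_layout_mdi, which no longer appear at this path in the 1.10 wheel. The following is a minimal, hypothetical usage sketch of the 1.9.2 API, reconstructed only from the signatures and kernels shown in this diff; the tensor shapes, the all-ones sparsity layouts, the int32 index dtype, and the CUDA device are illustrative assumptions, not taken from the package documentation.

import torch

from blksprs.ops.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
from blksprs.utils.blksprs_tensor import BlksprsTensor

# Illustrative assumptions: a CUDA device and 16x16 sparsity blocks.
device = torch.device("cuda")
sparsity_block_size = 16

# Source: one batch of a 32x32 matrix stored as four 16x16 blocks (fully dense layout).
sparsity_layout_src = torch.ones(1, 2, 2, dtype=torch.int, device=device)
src = BlksprsTensor(torch.randn(4, sparsity_block_size, sparsity_block_size, device=device))

# Index tensors share a single 1x1x2 layout (two blocks covering rows 0-15, columns 0-31).
sparsity_layout_idx = torch.ones(1, 1, 2, dtype=torch.int, device=device)
idx_bat = BlksprsTensor(torch.zeros(2, sparsity_block_size, sparsity_block_size,
                                    dtype=torch.int, device=device))  # batch index 0 everywhere
idx_col = BlksprsTensor(torch.randint(0, 32, (2, sparsity_block_size, sparsity_block_size),
                                      dtype=torch.int, device=device))  # dense column indices into src

# idx_row is accepted but unused by the removed implementation, so None is passed,
# mirroring the backward passes above.
gathered = gather_mdi(src, sparsity_layout_src, idx_bat, None, idx_col,
                      sparsity_layout_idx, sparsity_block_size)

# Scatter the gathered values back into the source layout, summing colliding column indices.
scattered = scatter_reduce_mdi(gathered, sparsity_layout_idx, idx_bat, None, idx_col,
                               sparsity_layout_src, sparsity_block_size, reduce_op="sum")

As the autograd Functions in the removed module show, the two ops serve as each other's backward pass: gather_mdi backpropagates through scatter_reduce_mdi with reduce_op="sum", and scatter_reduce_mdi backpropagates through gather_mdi.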