blksprs 1.10.1__py3-none-any.whl → 1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/softmax.py CHANGED
@@ -3,7 +3,6 @@ import triton
from torch import Tensor
from triton import language as tl

- from blksprs.ops.misc.exp import exp
from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
from blksprs.utils.blksprs_tensor import BlksprsTensor
from blksprs.utils.tools import get_triton_block_size, stride
@@ -12,7 +11,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v


def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-             triton_block_size: int = None) -> BlksprsTensor:
+             triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
    """Computes the softmax of a block-sparse tensor in compressed form.

    Note:
@@ -23,6 +22,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

    Returns:
        BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
@@ -37,24 +37,38 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

-     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
- 
-     sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-     sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
-     sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
-                                 (sparsity_layout_rws_flat == 1) -
-                                 (1 * (sparsity_layout_rws_flat == 0)))
- 
-     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
+     lut = _BlocksparseSoftmax.build_lut(lut, sparsity_layout)

    return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
-                                                    sparsity_lut,
-                                                    sparsity_reverse_lut_rws,
+                                                    lut["sparsity_lut"],
+                                                    lut["sparsity_reverse_lut_rws"],
                                                   sparsity_block_size, triton_block_size))


class _BlocksparseSoftmax(torch.autograd.Function):

+     @staticmethod
+     def build_lut(lut: dict, sparsity_layout: Tensor):
+         if lut is None:
+             lut = dict()
+ 
+         if "sparsity_lut" not in lut:
+             sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+             lut["sparsity_lut"] = sparsity_lut
+ 
+ 
+         if "sparsity_reverse_lut_rws" not in lut:
+             sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+             sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
+             sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
+                                         (sparsity_layout_rws_flat == 1) -
+                                         (1 * (sparsity_layout_rws_flat == 0)))
+             lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
+ 
+         validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
+ 
+         return lut
+ 
    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout: Tensor,
                sparsity_lut: Tensor,
@@ -72,7 +86,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
                                        flag_slice_only=True,
                                        triton_block_size=triton_block_size)
        x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
-         x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)
+         x_exp = torch.exp(x_scaled)
        x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                               flag_slice_only=True,
                                                               triton_block_size=triton_block_size)
@@ -182,29 +196,26 @@ class _BlocksparseSoftmax(torch.autograd.Function):
        rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

-         if rev_idx_spa_s == -1:
-             tl.device_assert(False)
-             return
- 
-         # Load x block
-         blk_x_idx = ((pid_blk * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+         if rev_idx_spa_s >= 0:
+             # Load x block
+             blk_x_idx = ((pid_blk * x_b_s) +
+                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+             blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

-         # Load sum block
-         blk_s_idx = (rev_idx_spa_s * s_b_s +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
-                      (tl.arange(0, 1) * s_c_s)[None, :])
-         blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
-         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
+             # Load sum block
+             blk_s_idx = (rev_idx_spa_s * s_b_s +
+                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                          (tl.arange(0, 1) * s_c_s)[None, :])
+             blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
+             blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

-         # Compute softmax
-         buf = tl.div_rn(blk_x, blk_s)
+             # Compute softmax
+             buf = tl.div_rn(blk_x, blk_s)

-         # Store output
-         tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
+             # Store output
+             tl.store(o + blk_x_idx, buf, mask=blk_x_msk)

    @staticmethod
    @triton.jit
@@ -239,32 +250,29 @@ class _BlocksparseSoftmax(torch.autograd.Function):
        rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

-         if rev_idx_spa_s == -1:
-             tl.device_assert(False)
-             return
- 
-         blk_s_idx = (rev_idx_spa_s * s_b_s +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
-                      (tl.arange(0, 1) * s_c_s)[None, :])
-         blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
-         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
- 
-         blk_g_idx = ((pid_blk * g_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-         blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
-         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
- 
-         blk_x_idx = ((pid_blk * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
- 
-         buf = blk_x * (blk_g - blk_s)
- 
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+         if rev_idx_spa_s >= 0:
+             blk_s_idx = (rev_idx_spa_s * s_b_s +
+                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                          (tl.arange(0, 1) * s_c_s)[None, :])
+             blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
+             blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
+ 
+             blk_g_idx = ((pid_blk * g_b_s) +
+                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
+                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
+             blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
+             blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
+ 
+             blk_x_idx = ((pid_blk * x_b_s) +
+                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+             blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+ 
+             buf = blk_x * (blk_g - blk_s)
+ 
+             blk_o_idx = ((pid_blk * o_b_s) +
+                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+             blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+             tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
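
Taken together, the softmax changes drop the bespoke blocksparse exp kernel in favour of torch.exp and let callers cache the look-up tables across invocations. A minimal usage sketch, not taken from the package itself: `x_bs`, `sparsity_layout`, and `sparsity_block_size` are assumed to be an already constructed compressed block-sparse tensor and its layout description.

    from blksprs.ops.softmax import softmax

    lut = {}  # persistent dict; build_lut fills missing entries in place on first use
    for _ in range(4):
        # The first call computes "sparsity_lut" and "sparsity_reverse_lut_rws" and
        # stores them in `lut`; later calls with the same layout reuse the cached entries
        # instead of repeating the torch.nonzero / torch.cumsum work.
        out = softmax(x_bs, sparsity_layout, sparsity_block_size, lut=lut)
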
blksprs/ops/transpose.py CHANGED
@@ -3,14 +3,15 @@ import triton
from torch import Tensor
from triton import language as tl

+ from blksprs.ops.flow import flow_forward_pull
from blksprs.utils.blksprs_tensor import BlksprsTensor
from blksprs.utils.tools import get_triton_block_size, stride
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


- def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
-         BlksprsTensor, Tensor):
+ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None,
+               lut: dict = None) -> (BlksprsTensor, Tensor):
    """Transposes a block-sparse tensor in compressed form.

    Note:
@@ -21,6 +22,7 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

    Returns:
        BlksprsTensor: The transposed block-sparse tensor in compressed form.
@@ -28,6 +30,7 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in

    """
    x = x.contiguous()
+     x_t = x.transpose(-1, -2).contiguous()

    validate_dimensions(x)
    validate_contiguous(x)
@@ -36,66 +39,53 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

-     sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
+     lut = _BlocksparseTranspose.build_lut(lut, sparsity_layout)

-     sparsity_lut = torch.nonzero(sparsity_layout_t).contiguous()
+     return BlksprsTensor(
+         _BlocksparseTranspose.apply(x_t, lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                     sparsity_block_size,
+                                     lut["n_sparse_blocks"], triton_block_size)), lut["sparsity_layout_t"]

-     sparsity_layout_flat = sparsity_layout.reshape(-1)
-     sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                              (sparsity_layout_flat == 1) -
-                              (1 * (sparsity_layout_flat == 0)))
-                             .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))

-     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+ class _BlocksparseTranspose(torch.autograd.Function):

-     validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)
+     @staticmethod
+     def build_lut(lut: dict, sparsity_layout: Tensor):
+         if lut is None:
+             lut = dict()

-     return BlksprsTensor(
-         _BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                                     n_sparse_blocks, triton_block_size)), sparsity_layout_t
+         if "sparsity_layout_t" not in lut:
+             sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
+             lut["sparsity_layout_t"] = sparsity_layout_t

+         if "sparsity_lut" not in lut:
+             sparsity_lut = torch.nonzero(lut["sparsity_layout_t"]).contiguous()
+             lut["sparsity_lut"] = sparsity_lut

- class _BlocksparseTranspose(torch.autograd.Function):
+         if "sparsity_reverse_lut" not in lut:
+             sparsity_layout_flat = sparsity_layout.reshape(-1)
+             sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                      (sparsity_layout_flat == 1) -
+                                      (1 * (sparsity_layout_flat == 0)))
+                                     .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
+             lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+ 
+         if "n_sparse_blocks" not in lut:
+             n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+             lut["n_sparse_blocks"] = n_sparse_blocks
+ 
+         validate_contiguous(lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+ 
+         return lut

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                sparsity_block_size: int,
                n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                              dtype=x.dtype, device=x.device)
- 
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
-         s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
-         s_lut_r, s_lut_c = sparsity_lut.shape
-         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
- 
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
- 
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
- 
-         (_BlocksparseTranspose.kernel_blocksparse_transpose[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-           sparsity_reverse_lut,
-           output,
-           o_b, o_b_s,
-           triton_block_size))
- 
-         # Save for backward pass
        ctx.save_for_backward(sparsity_layout_o)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size

-         return output
+         return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                  sparsity_block_size, n_sparse_blocks, triton_block_size)

    @staticmethod
    def backward(ctx, grad_output):
@@ -105,56 +95,3 @@ class _BlocksparseTranspose(torch.autograd.Function):

        return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
            0], None, None, None, None, None, None
- 
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_transpose(x,
-                                      x_b, x_b_s, x_r_s, x_c_s,
-                                      s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-                                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                      r_lut,
-                                      o,
-                                      o_b, o_b_s,
-                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
- 
-         # Get sparsity index of current output block consisting of its batch, row, and column index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
- 
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
- 
-         spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-         spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
- 
-         # Get reverse sparsity index
-         rev_idx_spa_idx = (spa_bat * s_l_b_s +
-                            spa_row * s_l_r_s +
-                            spa_col * s_l_c_s)
-         rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
-         rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
- 
-         if rev_idx_spa == -1:
-             tl.device_assert(False)
-             return
- 
-         blk_x_idx = (rev_idx_spa * x_b_s +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
- 
-         blk_x_t = tl.trans(blk_x)
- 
-         blk_o_idx = (pid_blk * o_b_s +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-         tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)
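
With the dedicated transpose kernel removed, the forward pass is delegated to flow_forward_pull and the LUT construction is cached the same way as in softmax. A sketch of the intended reuse, again assuming an existing compressed tensor `x_bs` with its `sparsity_layout` and `sparsity_block_size` (illustrative names, not from the diff):

    from blksprs.ops.transpose import transpose

    lut = {}  # filled in place by _BlocksparseTranspose.build_lut on the first call
    x_t, sparsity_layout_t = transpose(x_bs, sparsity_layout, sparsity_block_size, lut=lut)
    # A second transpose with the same layout reuses the cached transposed layout,
    # look-up tables, and sparse block count instead of recomputing them.
    x_t2, _ = transpose(x_bs, sparsity_layout, sparsity_block_size, lut=lut)
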
blksprs/utils/tools.py CHANGED
@@ -1,3 +1,4 @@
+ import torch
from torch import Tensor, Size


@@ -20,4 +21,9 @@ def get_triton_block_size(sparsity_block_size: int, limit: int = 128):


def stride(x: Tensor):
-     return x.view(x.shape).stride()
+     if x.dim() == 2:
+         return x.size(1), 1
+     elif x.dim() == 3:
+         return x.size(1) * x.size(2), x.size(2), 1
+     else:
+         raise NotImplementedError
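
The rewritten stride helper derives strides from the shape rather than reading them from the tensor, so it only matches the true strides for contiguous 2-D and 3-D inputs, which the library enforces via .contiguous() and validate_contiguous before use. An illustrative check (not part of the package):

    import torch
    from blksprs.utils.tools import stride

    a = torch.empty(4, 8)            # contiguous 2-D tensor
    b = torch.empty(2, 4, 8)         # contiguous 3-D tensor
    assert stride(a) == a.stride()   # (8, 1)
    assert stride(b) == b.stride()   # (32, 8, 1)
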
blksprs-1.10.1.dist-info/METADATA → blksprs-1.11.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
Name: blksprs
- Version: 1.10.1
+ Version: 1.11
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
blksprs-1.11.dist-info/RECORD ADDED
@@ -0,0 +1,23 @@
+ blksprs/__init__.py,sha256=AJYVfR40nOfE5F3waHPVSuajwYDcoGkiEQc8HhQbUBU,1721
+ blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
+ blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
+ blksprs/ops/conversion.py,sha256=QFtZ-nmY2JAWutheiO07vatXqz3eSZBP5Ym_U2Q1oWk,23299
+ blksprs/ops/distribution.py,sha256=nHTuE7Tq0Q404VN8bWNC2sEwmmdAtgZI6I7auRICdps,21749
+ blksprs/ops/flow.py,sha256=7tOXfTBKOAixYmDa_VXg7TwviLV5ZQMHQjtbyOjqA00,7879
+ blksprs/ops/matmul.py,sha256=eVj_BGj78bJkXYuvw4KctMfcfveQBt5OdYmeXzdpO88,12631
+ blksprs/ops/partitioning.py,sha256=qMv9w3yFWXwXIhIppdcJ_JMsoZ25HCH38vb6GRneoLM,10416
+ blksprs/ops/repeat.py,sha256=i824ijprfYpCaEjiSG5FTUZz7wMS5ksVy_-vY7ZX8Fg,9729
+ blksprs/ops/softmax.py,sha256=_mGkA2jHN8cXwtWXYswobEPyM7UC0JyzRszoE4ZYs7w,13063
+ blksprs/ops/transpose.py,sha256=O1XhGIGiVkhOSKcBD0HrYaeK6HmpvEEzLb7zJl7xsIM,4246
+ blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
+ blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
+ blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+ blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
+ blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
+ blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
+ blksprs/utils/tools.py,sha256=k2OfEplbQiAwVjP84zZf7SNB8FzvMtOFBL9sC98OCbI,683
+ blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
+ blksprs-1.11.dist-info/METADATA,sha256=NUEiHexWiFNbMxQI2TUEzMw9iGBhxqflhWr2xCgOw28,9105
+ blksprs-1.11.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
+ blksprs-1.11.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+ blksprs-1.11.dist-info/RECORD,,
blksprs-1.10.1.dist-info/WHEEL → blksprs-1.11.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
Wheel-Version: 1.0
- Generator: setuptools (75.6.0)
+ Generator: setuptools (76.0.0)
Root-Is-Purelib: true
Tag: py3-none-any

blksprs/ops/misc/exp.py DELETED
@@ -1,104 +0,0 @@
- import torch
- import triton
- from torch import Tensor
- from triton import language as tl
- 
- from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity_block_size, validate_triton_block_size
- 
- 
- def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
-     """Applies the element-wise exponential function to a block-sparse tensor.
- 
-     Note:
-         This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
-         Consider this when converting back to tensors in regular form.
- 
-     Args:
-         x (BlksprsTensor): A block-sparse tensor in compressed form.
-         sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
- 
-     Returns:
-         BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
-             compressed form.
- 
-     """
-     x = x.contiguous()
- 
-     validate_dimensions(x)
-     validate_contiguous(x)
-     validate_device(x)
-     validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
- 
-     return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))
- 
- 
- class _BlocksparseExp(torch.autograd.Function):
- 
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:
-         output = torch.empty_like(x)
- 
-         x_b, x_r, x_c = x.shape
-         x_b_s, x_r_s, x_c_s = stride(x)
-         o_b, o_r, o_c = output.shape
-         o_b_s, o_r_s, o_c_s = stride(output)
- 
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
- 
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
- 
-         (_BlocksparseExp.kernel_blocksparse_exp[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           triton_block_size))
- 
-         ctx.save_for_backward(output)
- 
-         return output
- 
-     @staticmethod
-     def backward(ctx, grad_output):
-         o = ctx.saved_tensors[0]
- 
-         grad_x = torch.mul(grad_output, o)
- 
-         return grad_x, None, None
- 
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_exp(x,
-                                x_b, x_b_s, x_r_s, x_c_s,
-                                o,
-                                o_b, o_b_s, o_r_s, o_c_s,
-                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
- 
-         # Load block
-         blk_x_idx = ((pid_blk * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
- 
-         # Compute exp
-         buf = tl.exp(blk_x)
- 
-         # Store block
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
blksprs-1.10.1.dist-info/RECORD DELETED
@@ -1,24 +0,0 @@
- blksprs/__init__.py,sha256=wnpk-20jXq7xV0xa-WpHfPQuauI2gEZz9sH-0blKxP0,1766
- blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
- blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
- blksprs/ops/conversion.py,sha256=NK5uXMepPJ9yYh0vnxKwx5_Ffj_bAvhqPVogf_7PY0g,22248
- blksprs/ops/distribution.py,sha256=qK5t5XgQSJxXPced8RohprqCtUMMTaEP2pFm3KU1c8o,20267
- blksprs/ops/flow.py,sha256=SWHDQ5zx0cjnPR0CcAcRNZdSusSAHSU840SwDNUr24g,6437
- blksprs/ops/matmul.py,sha256=LAQyPNwWVmBMRnAex3msLSPD_aG5SblLCMiutJWqvus,11632
- blksprs/ops/partitioning.py,sha256=ugKnpvH36ND7qeJQp56M74qqfACkzcTVuXebzw__28Y,8286
- blksprs/ops/repeat.py,sha256=RCa-dITomA5v12K5Oxa5_ReA361zS7WHPNNHxSp9PGw,8578
- blksprs/ops/softmax.py,sha256=i8NJhvPRYya94AzpN6qiki6_G9KfDrtPifhWd7wbYzk,12496
- blksprs/ops/transpose.py,sha256=oAtUu7QzQnNAH3lvRs_MIvIKpBu9h74f9Sk07AxKnDM,6991
- blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
- blksprs/ops/misc/exp.py,sha256=ygfw7oD6ALdPwNQX_HelKgO8I3-LCgIXH_x0gWzkUN8,3840
- blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
- blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
- blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
- blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
- blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
- blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
- blksprs-1.10.1.dist-info/METADATA,sha256=5in6lYCZo1bd8urYR0wkTxIiTTAIAANukLpKeZfGasY,9107
- blksprs-1.10.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- blksprs-1.10.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-1.10.1.dist-info/RECORD,,